Spaces:
Sleeping
Sleeping
nafisehNik
commited on
Commit
•
74688de
1
Parent(s):
9e81616
space created
Browse files- .streamlit/config.toml +6 -0
- app.py +94 -0
- assets/logo.svg +1 -0
- requirements.txt +1 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
primaryColor="#FF8000"
|
3 |
+
#backgroundColor="#FFFFFF"
|
4 |
+
#secondaryBackgroundColor="#F0F2F6"
|
5 |
+
#textColor="#262730"
|
6 |
+
#font="sans serif"
|
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The GIRT Authors.
|
3 |
+
# Lint as: python3
|
4 |
+
|
5 |
+
|
6 |
+
# This space is built based on AMR-KELEG/ALDi and cis-lmu/GlotLID space.
|
7 |
+
# GIRT Space
|
8 |
+
|
9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
10 |
+
import streamlit as st
|
11 |
+
import base64
|
12 |
+
|
13 |
+
|
14 |
+
@st.cache_data
|
15 |
+
def render_svg(svg):
|
16 |
+
"""Renders the given svg string."""
|
17 |
+
b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
|
18 |
+
html = rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}", width="40%"/> </p>'
|
19 |
+
c = st.container()
|
20 |
+
c.write(html, unsafe_allow_html=True)
|
21 |
+
|
22 |
+
|
23 |
+
@st.cache_resource
|
24 |
+
def load_model(model_name):
|
25 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
26 |
+
return model
|
27 |
+
|
28 |
+
@st.cache_resource
|
29 |
+
def load_tokenizer(model_name):
|
30 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
31 |
+
return tokenizer
|
32 |
+
|
33 |
+
with st.spinner(text="Please wait while the model is loading...."):
|
34 |
+
|
35 |
+
model = load_model('nafisehNik/girt-t5-base')
|
36 |
+
tokenizer = load_tokenizer('nafisehNik/girt-t5-base')
|
37 |
+
|
38 |
+
|
39 |
+
def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length):
|
40 |
+
|
41 |
+
inputs = tokenizer(sample, return_tensors="pt").to('cpu')
|
42 |
+
|
43 |
+
outputs = model.generate(
|
44 |
+
**inputs,
|
45 |
+
num_beams=num_beams,
|
46 |
+
num_return_sequences=1,
|
47 |
+
length_penalty=length_penalty,
|
48 |
+
no_repeat_ngram_size=2,
|
49 |
+
early_stopping=early_stopping,
|
50 |
+
max_length=max_length,
|
51 |
+
min_length=min_length).to('cpu')
|
52 |
+
|
53 |
+
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
|
54 |
+
generated_text = generated_texts[0]
|
55 |
+
|
56 |
+
replace_dict = {
|
57 |
+
'\n ': '\n',
|
58 |
+
'</s>': '',
|
59 |
+
'<pad> ': '',
|
60 |
+
'<pad>': '',
|
61 |
+
'<unk>': ''
|
62 |
+
}
|
63 |
+
|
64 |
+
postprocess_text = generated_text
|
65 |
+
for key, value in replace_dict.items():
|
66 |
+
postprocess_text = postprocess_text.replace(key, value)
|
67 |
+
|
68 |
+
|
69 |
+
return postprocess_text
|
70 |
+
|
71 |
+
|
72 |
+
st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/nafisehNik/girt-space?duplicate=true)")
|
73 |
+
|
74 |
+
render_svg(open("assets/logo.svg").read())
|
75 |
+
|
76 |
+
tab1, tab2 = st.tabs(["Design GitHub Issue Template", "Manual Prompt"])
|
77 |
+
|
78 |
+
with tab1:
|
79 |
+
pass
|
80 |
+
|
81 |
+
with tab2:
|
82 |
+
|
83 |
+
sent = st.text_input(
|
84 |
+
"Sentence:", placeholder="Enter a prompt.", on_change=None
|
85 |
+
)
|
86 |
+
|
87 |
+
# TODO: Check if this is needed!
|
88 |
+
clicked = st.button("Submit")
|
89 |
+
|
90 |
+
if sent:
|
91 |
+
res = compute(sent, num_beams=2, length_penalty=1.0, early_stopping=True, max_length=300, min_length=20)
|
92 |
+
st.code(res, language="python")
|
93 |
+
|
94 |
+
|
assets/logo.svg
ADDED
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
transformers>=4.35.0,<4.45.0
|