nafisehNik committed on
Commit
74688de
1 Parent(s): 9e81616

space created

Browse files
Files changed (4) hide show
  1. .streamlit/config.toml +6 -0
  2. app.py +94 -0
  3. assets/logo.svg +1 -0
  4. requirements.txt +1 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#FF8000"
3
+ #backgroundColor="#FFFFFF"
4
+ #secondaryBackgroundColor="#F0F2F6"
5
+ #textColor="#262730"
6
+ #font="sans serif"
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The GIRT Authors.
3
+ # Lint as: python3
4
+
5
+
6
+ # This space is built based on AMR-KELEG/ALDi and cis-lmu/GlotLID space.
7
+ # GIRT Space
8
+
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+ import streamlit as st
11
+ import base64
12
+
13
+
14
@st.cache_data
def render_svg(svg):
    """Render an SVG string as a centered inline image.

    The SVG text is UTF-8 encoded, base64-encoded and embedded as a
    ``data:image/svg+xml`` URI in an ``<img>`` tag, which is then written
    into a new Streamlit container as raw HTML.

    Args:
        svg: The SVG document as a string.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # HTML attributes are separated by whitespace only; a comma after the
    # ``src`` value is a parse error, so none is emitted here.
    html = (
        '<p align="center"> '
        f'<img src="data:image/svg+xml;base64,{b64}" width="40%"/> '
        "</p>"
    )
    container = st.container()
    # The markup is passed with unsafe_allow_html=True so it is treated as
    # HTML rather than plain text.
    container.write(html, unsafe_allow_html=True)
21
+
22
+
23
@st.cache_resource
def load_model(model_name):
    """Load the seq2seq model identified by ``model_name``.

    Args:
        model_name: Identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)
27
+
28
@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer identified by ``model_name``.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object returned by ``from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_name)
32
+
33
# The model and its tokenizer are loaded from the same checkpoint while a
# spinner is displayed; the model is loaded first, then the tokenizer.
checkpoint = 'nafisehNik/girt-t5-base'
with st.spinner(text="Please wait while the model is loading...."):
    model, tokenizer = load_model(checkpoint), load_tokenizer(checkpoint)
37
+
38
+
39
def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length):
    """Generate text for ``sample`` using the module-level model and tokenizer.

    Args:
        sample: Input text handed to the tokenizer.
        num_beams: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        early_stopping: Forwarded to ``model.generate``.
        max_length: Forwarded to ``model.generate``.
        min_length: Forwarded to ``model.generate``.

    Returns:
        The first decoded sequence with ``</s>``, ``<pad>`` and ``<unk>``
        removed and every newline-plus-space collapsed to a newline.
    """
    encoded = tokenizer(sample, return_tensors="pt").to('cpu')

    # Fixed settings (single returned sequence, no repeated bigrams) plus
    # the caller-supplied generation options.
    generation_kwargs = dict(
        num_beams=num_beams,
        num_return_sequences=1,
        length_penalty=length_penalty,
        no_repeat_ngram_size=2,
        early_stopping=early_stopping,
        max_length=max_length,
        min_length=min_length,
    )
    sequences = model.generate(**encoded, **generation_kwargs).to('cpu')

    # Special tokens are kept while decoding and stripped by hand below.
    text = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]

    # Order matters: '<pad> ' has to be handled before the bare '<pad>'.
    substitutions = (
        ('\n ', '\n'),
        ('</s>', ''),
        ('<pad> ', ''),
        ('<pad>', ''),
        ('<unk>', ''),
    )
    for old, new in substitutions:
        text = text.replace(old, new)

    return text
70
+
71
+
72
# Badge linking to the "duplicate this Space" page.
st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/nafisehNik/girt-space?duplicate=true)")

# Read the logo through a context manager so the file handle is closed
# deterministically; the encoding is explicit because render_svg re-encodes
# the text as UTF-8.
with open("assets/logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Design GitHub Issue Template", "Manual Prompt"])

# The first tab is an empty placeholder.
with tab1:
    pass

with tab2:

    sent = st.text_input(
        "Sentence:", placeholder="Enter a prompt.", on_change=None
    )

    # TODO: Check if this is needed!
    # NOTE(review): `clicked` is never read; generation below is gated only
    # on `sent` being non-empty.
    clicked = st.button("Submit")

    if sent:
        res = compute(sent, num_beams=2, length_penalty=1.0, early_stopping=True, max_length=300, min_length=20)
        st.code(res, language="python")
93
+
94
+
assets/logo.svg ADDED
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ transformers>=4.35.0,<4.45.0