cahya commited on
Commit
9177921
1 Parent(s): b466416

add hash function, temperature, random seed

Browse files
Files changed (2) hide show
  1. app/app.py +34 -14
  2. requirements.txt +2 -1
app/app.py CHANGED
@@ -5,6 +5,7 @@ from prompts import PROMPT_LIST
5
  import random
6
  import time
7
  from transformers import pipeline, set_seed
 
8
 
9
  # st.set_page_config(page_title="Image Search")
10
 
@@ -19,13 +20,14 @@ def get_generator():
19
  return text_generator
20
 
21
 
22
- #@st.cache(suppress_st_warning=True)
23
  def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
24
- temperature: float = 1.0, max_time: float = None, seed=42):
25
- # st.write("Cache miss: process")
26
  set_seed(seed)
27
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
28
- top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
 
29
  return result
30
 
31
 
@@ -65,37 +67,55 @@ max_length = st.sidebar.number_input(
65
  help="The maximum length of the sequence to be generated."
66
  )
67
 
68
- temp = st.sidebar.slider(
69
  "Temperature",
70
  value=1.0,
71
  min_value=0.0,
72
- max_value=100.0
73
  )
74
 
75
- top_k = st.sidebar.number_input(
76
- "Top k",
77
- value=25
78
  )
79
 
80
- top_p = st.sidebar.number_input(
81
- "Top p",
82
- value=0.95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
 
 
85
  text_generator = get_generator()
86
  if st.button("Run"):
87
  with st.spinner(text="Getting results..."):
88
  st.subheader("Result")
89
  time_start = time.time()
90
- result = process(text=session_state.text, max_length=int(max_length), top_k=int(top_k), top_p=float(top_p))
 
 
91
  time_end = time.time()
92
  time_diff = time_end-time_start
93
- #print(f"Text generated in {time_diff} seconds")
94
  result = result[0]["generated_text"]
95
  st.write(result.replace("\n", " \n"))
96
  st.text("Translation")
97
  translation = translate(result, "en", "id")
98
  st.write(translation.replace("\n", " \n"))
 
 
99
 
100
  # Reset state
101
  session_state.prompt = None
 
5
  import random
6
  import time
7
  from transformers import pipeline, set_seed
8
+ import tokenizers
9
 
10
  # st.set_page_config(page_title="Image Search")
11
 
 
20
  return text_generator
21
 
22
 
23
+ @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
24
  def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
25
+ temperature: float = 1.0, max_time: float = 10.0, seed=42):
26
+ st.write("Cache miss: process")
27
  set_seed(seed)
28
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
29
+ top_k=top_k, top_p=top_p, temperature=temperature,
30
+ max_time=max_time)
31
  return result
32
 
33
 
 
67
  help="The maximum length of the sequence to be generated."
68
  )
69
 
70
+ temperature = st.sidebar.slider(
71
  "Temperature",
72
  value=1.0,
73
  min_value=0.0,
74
+ max_value=10.0
75
  )
76
 
77
+ do_sample = st.sidebar.checkbox(
78
+ "Use sampling",
79
+ value=True
80
  )
81
 
82
+ top_k = 25
83
+ top_p = 0.95
84
+
85
+ if do_sample:
86
+ top_k = st.sidebar.number_input(
87
+ "Top k",
88
+ value=top_k
89
+ )
90
+ top_p = st.sidebar.number_input(
91
+ "Top p",
92
+ value=top_p
93
+ )
94
+
95
+ seed = st.sidebar.number_input(
96
+ "Random Seed",
97
+ value=25,
98
+ help="The number used to initialize a pseudorandom number generator"
99
  )
100
 
101
+
102
  text_generator = get_generator()
103
  if st.button("Run"):
104
  with st.spinner(text="Getting results..."):
105
  st.subheader("Result")
106
  time_start = time.time()
107
+ result = process(text=session_state.text, max_length=int(max_length),
108
+ temperature=temperature, do_sample=do_sample,
109
+ top_k=int(top_k), top_p=float(top_p), seed=seed)
110
  time_end = time.time()
111
  time_diff = time_end-time_start
 
112
  result = result[0]["generated_text"]
113
  st.write(result.replace("\n", " \n"))
114
  st.text("Translation")
115
  translation = translate(result, "en", "id")
116
  st.write(translation.replace("\n", " \n"))
117
+ # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
118
+ st.write(f"*Text generated in {time_diff:.5} seconds*")
119
 
120
  # Reset state
121
  session_state.prompt = None
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
4
  datasets
5
  mtranslate
6
  # streamlit version 0.67.1 is needed due to issue with caching
7
- streamlit==0.67.1
 
 
4
  datasets
5
  mtranslate
6
  # streamlit version 0.67.1 is needed due to issue with caching
7
+ # streamlit==0.67.1
8
+ streamlit