UNIST-Eunchan committed on
Commit
d872090
1 Parent(s): fcaeeb1

Update app.py

Files changed (1)
  1. app.py +27 -34
app.py CHANGED
@@ -52,7 +52,7 @@ def chunking(book_text):
         #make next_pseudo_segment
         next_pseudo_segment = ""
         next_token_len = 0
-        for t in range(30):
+        for t in range(10):
             if (i+t < len(sentences)) and (next_token_len + token_lens[i+t] < 512):
                 next_token_len += token_lens[i+t]
                 next_pseudo_segment += sentences[i+t]
@@ -69,8 +69,6 @@ def chunking(book_text):
     return segments
 
 
-chunked_segments = chunking(test_book[0]['book'])
-
 '''
 '''
 
@@ -83,42 +81,37 @@ book_index = st.sidebar.slider("Select Book Example", value = 0,min_value = 0, m
 _book = test_book[book_index]['book']
 chunked_segments = chunking(_book)
 
-sent = st.text_area("Text", _book, height = 550)
+sent = st.text_area("Text", _book[:512], height = 550)
 max_length = st.sidebar.slider("Max Length", value = 512, min_value = 10, max_value = 1024)
 temperature = st.sidebar.slider("Temperature", value = 1.0, min_value = 0.0, max_value = 1.0, step = 0.05)
 top_k = st.sidebar.slider("Top-k", min_value = 0, max_value = 5, value = 0)
 top_p = st.sidebar.slider("Top-p", min_value = 0.0, max_value = 1.0, step = 0.05, value = 0.92)
 
+def generate_output(test_samples):
+    inputs = tokenizer(
+        test_samples,
+        padding="max_length",
+        truncation=True,
+        max_length=1024,
+        return_tensors="pt",
+    )
+    input_ids = inputs.input_ids.to(model.device)
+    attention_mask = inputs.attention_mask.to(model.device)
+    outputs = model.generate(input_ids,
+                             max_length=256,
+                             min_length=32,
+                             top_p=0.92,
+                             num_beams=5,
+                             no_repeat_ngram_size=2,
+                             attention_mask=attention_mask)
+    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return outputs, output_str
+
+
+chunked_segments = chunking(test_book[0]['book'])
+
 
 for segment in range(len(chunked_segments)):
 
-    encoded_prompt = tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt")
-    if encoded_prompt.size()[-1] == 0:
-        input_ids = None
-    else:
-        input_ids = encoded_prompt
-
-
-    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
-
-
-    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
-        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
-        generated_sequences = generated_sequence.tolist()
-
-        # Decode text
-        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-
-        # Remove all text after the stop token
-        #text = text[: text.find(args.stop_token) if args.stop_token else None]
-
-        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
-        total_sequence = (
-            sent + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
-        )
-
-        generated_sequences.append(total_sequence)
-        print(total_sequence)
-
-
-    st.write(generated_sequences[-1])
+    summaries = generate_output(chunked_segments[segment])
+    st.write(summaries[-1])
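
For reference, here is a minimal, self-contained sketch of the chunk-and-summarize flow this commit moves app.py to. It assumes a generic Hugging Face seq2seq checkpoint (sshleifer/distilbart-cnn-12-6 is a stand-in; app.py loads its own model and tokenizer elsewhere) and a hard-coded list standing in for chunking()'s output of sub-512-token pseudo-segments:

# Sketch only, under the assumptions stated above.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "sshleifer/distilbart-cnn-12-6"  # stand-in checkpoint, not the app's model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Stand-in for chunking(test_book[0]['book']): a list of short text chunks.
chunked_segments = [
    "First pseudo-segment of the book ...",
    "Second pseudo-segment of the book ...",
]

def generate_output(test_samples):
    # Tokenize one chunk, padding/truncating to the encoder's input budget.
    inputs = tokenizer(
        test_samples,
        padding="max_length",
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    # Beam search with a bigram-repetition block, as in the commit; top_p is
    # omitted here because it only takes effect when do_sample=True.
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=256,
        min_length=32,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str

# Summarize chunk by chunk; it is the chunk text, not the loop index,
# that must reach the tokenizer.
for chunk in chunked_segments:
    _, summary = generate_output(chunk)
    print(summary[0])

Since every call pads to 1024 tokens and runs 5 beams, generate() dominates the runtime on a long book; passing the whole list of chunks to the tokenizer in one batched call would amortize that cost.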