nawhgnuj committed
Commit: c86d108
Parent(s): 593d7cf

Update app.py

Files changed (1):
  app.py  +14 -29
app.py CHANGED
@@ -73,13 +73,6 @@ def stream_chat(
     Importantly, always respond to points in Trump's style. Keep responses concise and avoid unnecessary repetition.
     """
 
-    temperature = 0.1
-    max_new_tokens = 256
-    top_p = 0.9
-    top_k = 20
-    repetition_penalty = 1.5
-    no_repeat_ngram_size = 3
-
     conversation = [
         {"role": "system", "content": system_prompt}
     ]
@@ -95,29 +88,21 @@ def stream_chat(
 
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        streamer=streamer,
-    )
-
     with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-            yield buffer
+        output = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=1024,
+            do_sample=True,
+            top_p=1.0,
+            top_k=20,
+            temperature=0.8,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+        response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
+        return response.strip()
 
 def add_text(history, text):
     history = history + [(text, None)]
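In short, the commit swaps the threaded streaming pattern (a TextIteratorStreamer consumed by a generator that yields a growing buffer) for a single blocking model.generate call whose output is decoded once, minus the prompt tokens, and returned as a plain string. As a side effect, stream_chat stops being a generator, and the streamer constructed just above the generate call is left unused by this diff. The sketch below contrasts the two patterns under a standard transformers causal-LM setup; "gpt2" is a placeholder checkpoint, not the model this app actually loads.

# Minimal sketch, not the app's actual code: "gpt2" stands in for the
# model that app.py really loads, which this diff does not show.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

checkpoint = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id  # gpt2 has no pad token

inputs = tokenizer("Hello there.", return_tensors="pt")

def reply_streaming():
    """Old pattern: run generate on a worker thread and yield partial text."""
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    kwargs = dict(**inputs, max_new_tokens=32, do_sample=True,
                  pad_token_id=tokenizer.pad_token_id, streamer=streamer)
    Thread(target=model.generate, kwargs=kwargs).start()
    buffer = ""
    for new_text in streamer:   # pieces arrive while generation is running
        buffer += new_text
        yield buffer            # caller sees the reply grow incrementally

def reply_blocking():
    """New pattern: wait for the whole generation, return one string."""
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=32, do_sample=True,
                                pad_token_id=tokenizer.pad_token_id)
    # Decode only the tokens generated after the prompt.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

for partial in reply_streaming():
    print(partial)              # reply grows one chunk at a time
print(reply_blocking())         # reply arrives all at once

The trade-off is simplicity versus responsiveness: the blocking form is easier to reason about and to test, but a chat UI consuming it shows nothing until the full reply is ready, whereas the generator form can update the chat box as tokens are produced.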