Hansimov committed on
Commit
62d5db7
1 Parent(s): 64105c5

:gem: [Feature] HuggingchatStreamer: Enable chat_return_generator

Files changed (1)
  1. networks/huggingchat_streamer.py +44 -6
networks/huggingchat_streamer.py CHANGED
@@ -165,9 +165,8 @@ class HuggingchatRequester:
         else:
             logger_func = logger.warn
 
-        logger_func(status_code_str)
-
         logger.enter_quiet(not verbose)
+        logger_func(status_code_str)
 
         if status_code != 200:
             logger_func(res.text)
@@ -211,6 +210,7 @@ class HuggingchatRequester:
         checker = TokenChecker(input_str=system_prompt + input_prompt, model=self.model)
         checker.check_token_limit()
 
+        logger.enter_quiet(not verbose)
         self.get_hf_chat_id()
         self.get_conversation_id(system_prompt=system_prompt)
         message_id = self.get_last_message_id()
@@ -232,6 +232,7 @@
             "web_search": False,
         }
         self.log_request(request_url, method="POST")
+        logger.exit_quiet(not verbose)
 
         res = requests.post(
             request_url,
@@ -256,13 +257,50 @@ class HuggingchatStreamer:
     def chat_response(self, messages: list[dict], verbose=False):
         requester = HuggingchatRequester(model=self.model)
         return requester.chat_completions(
-            messages=messages, iter_lines=True, verbose=True
+            messages=messages, iter_lines=False, verbose=verbose
         )
 
-    def chat_return_dict(self, stream_response):
-        pass
+    def chat_return_generator(self, stream_response: requests.Response, verbose=False):
+        is_finished = False
+        for line in stream_response.iter_lines():
+            line = line.decode("utf-8")
+            line = re.sub(r"^data:\s*", "", line)
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                data = json.loads(line, strict=False)
+                msg_type = data.get("type")
+                if msg_type == "status":
+                    msg_status = data.get("status")
+                    continue
+                elif msg_type == "stream":
+                    content_type = "Completions"
+                    content = data.get("token", "")
+                    if verbose:
+                        logger.success(content, end="")
+                elif msg_type == "finalAnswer":
+                    content_type = "Finished"
+                    content = ""
+                    full_content = data.get("text")
+                    if verbose:
+                        logger.success("\n[Finished]")
+                    is_finished = True
+                    break
+                else:
+                    continue
+            except Exception as e:
+                logger.warn(e)
+
+            output = self.message_outputer.output(
+                content=content, content_type=content_type
+            )
+            yield output
 
-    def chat_return_generator(self, stream_response):
+        if not is_finished:
+            yield self.message_outputer.output(content="", content_type="Finished")
+
+    def chat_return_dict(self, stream_response):
         pass
 
 
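
For context, a minimal sketch of how the new streaming path might be consumed downstream. It assumes only what the diff suggests: a HuggingchatStreamer(model=...) constructor, chat_response returning the raw streaming requests.Response (now that iter_lines=False), and chat_return_generator yielding one formatted chunk per "stream" token before closing with a "Finished" chunk. The model id, function name, and chunk handling below are illustrative assumptions, not part of this commit:

# Hypothetical consumer of chat_return_generator (constructor args and model id are assumptions)
from networks.huggingchat_streamer import HuggingchatStreamer

def stream_chat(messages: list[dict], model: str = "mixtral-8x7b"):
    # model id is an assumed example, not taken from this diff
    streamer = HuggingchatStreamer(model=model)
    # chat_response now returns the raw streaming response (iter_lines=False),
    # so the generator below is what walks the SSE-style "data:" lines.
    response = streamer.chat_response(messages=messages, verbose=False)
    # Yields one formatted chunk per streamed token and always ends with a
    # "Finished" chunk, even if the upstream stream stops before "finalAnswer".
    yield from streamer.chat_return_generator(response, verbose=False)

if __name__ == "__main__":
    for chunk in stream_chat([{"role": "user", "content": "Hello!"}]):
        print(chunk)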