zyliu commited on
Commit
192cdc0
1 Parent(s): 9ec7317

update model_worker.py

Browse files
Files changed (1) hide show
  1. model_worker.py +15 -14
model_worker.py CHANGED
@@ -325,7 +325,8 @@ class ModelWorker:
325
  "queue_length": self.get_queue_length(),
326
  }
327
 
328
- @torch.inference_mode()
 
329
  def generate_stream(self, params):
330
  system_message = params["prompt"][0]["content"]
331
  send_messages = params["prompt"][1:]
@@ -428,18 +429,19 @@ class ModelWorker:
428
  )
429
  logger.info(f"Generation config: {generation_config}")
430
 
431
- thread = Thread(
432
- target=self.model.chat,
433
- kwargs=dict(
434
- tokenizer=self.tokenizer,
435
- pixel_values=pixel_values,
436
- question=question,
437
- history=history,
438
- return_history=False,
439
- generation_config=generation_config,
440
- ),
441
- )
442
- thread.start()
 
443
 
444
  generated_text = ""
445
  for new_text in streamer:
@@ -453,7 +455,6 @@ class ModelWorker:
453
  )
454
  self.model.system_message = old_system_message
455
 
456
- @spaces.GPU(duration=120)
457
  def generate_stream_gate(self, params):
458
  try:
459
  for x in self.generate_stream(params):
 
325
  "queue_length": self.get_queue_length(),
326
  }
327
 
328
+ # @torch.inference_mode()
329
+ @spaces.GPU(duration=120)
330
  def generate_stream(self, params):
331
  system_message = params["prompt"][0]["content"]
332
  send_messages = params["prompt"][1:]
 
429
  )
430
  logger.info(f"Generation config: {generation_config}")
431
 
432
+ with torch.no_grad():
433
+ thread = Thread(
434
+ target=self.model.chat,
435
+ kwargs=dict(
436
+ tokenizer=self.tokenizer,
437
+ pixel_values=pixel_values,
438
+ question=question,
439
+ history=history,
440
+ return_history=False,
441
+ generation_config=generation_config,
442
+ ),
443
+ )
444
+ thread.start()
445
 
446
  generated_text = ""
447
  for new_text in streamer:
 
455
  )
456
  self.model.system_message = old_system_message
457
 
 
458
  def generate_stream_gate(self, params):
459
  try:
460
  for x in self.generate_stream(params):