Spaces:
Running
Running
update model_worker.py
Browse files - model_worker.py (+15 -14)
model_worker.py
CHANGED
@@ -325,7 +325,8 @@ class ModelWorker:
|
|
325 |
"queue_length": self.get_queue_length(),
|
326 |
}
|
327 |
|
328 |
-
@torch.inference_mode()
|
|
|
329 |
def generate_stream(self, params):
|
330 |
system_message = params["prompt"][0]["content"]
|
331 |
send_messages = params["prompt"][1:]
|
@@ -428,18 +429,19 @@ class ModelWorker:
|
|
428 |
)
|
429 |
logger.info(f"Generation config: {generation_config}")
|
430 |
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
443 |
|
444 |
generated_text = ""
|
445 |
for new_text in streamer:
|
@@ -453,7 +455,6 @@ class ModelWorker:
|
|
453 |
)
|
454 |
self.model.system_message = old_system_message
|
455 |
|
456 |
-
@spaces.GPU(duration=120)
|
457 |
def generate_stream_gate(self, params):
|
458 |
try:
|
459 |
for x in self.generate_stream(params):
|
|
|
325 |
"queue_length": self.get_queue_length(),
|
326 |
}
|
327 |
|
328 |
+
# @torch.inference_mode()
|
329 |
+
@spaces.GPU(duration=120)
|
330 |
def generate_stream(self, params):
|
331 |
system_message = params["prompt"][0]["content"]
|
332 |
send_messages = params["prompt"][1:]
|
|
|
429 |
)
|
430 |
logger.info(f"Generation config: {generation_config}")
|
431 |
|
432 |
+
with torch.no_grad():
|
433 |
+
thread = Thread(
|
434 |
+
target=self.model.chat,
|
435 |
+
kwargs=dict(
|
436 |
+
tokenizer=self.tokenizer,
|
437 |
+
pixel_values=pixel_values,
|
438 |
+
question=question,
|
439 |
+
history=history,
|
440 |
+
return_history=False,
|
441 |
+
generation_config=generation_config,
|
442 |
+
),
|
443 |
+
)
|
444 |
+
thread.start()
|
445 |
|
446 |
generated_text = ""
|
447 |
for new_text in streamer:
|
|
|
455 |
)
|
456 |
self.model.system_message = old_system_message
|
457 |
|
|
|
458 |
def generate_stream_gate(self, params):
|
459 |
try:
|
460 |
for x in self.generate_stream(params):
|