Spaces: Running on Zero
Update multipurpose_chatbot/engines/transformers_engine.py
Browse files
multipurpose_chatbot/engines/transformers_engine.py
CHANGED
@@ -429,7 +429,7 @@ class TransformersEngine(BaseEngine):
|
|
429 |
|
430 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
431 |
import sys
|
432 |
-
|
433 |
with torch.no_grad():
|
434 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
435 |
num_tokens = inputs.input_ids.size(1)
|
@@ -450,7 +450,7 @@ class TransformersEngine(BaseEngine):
|
|
450 |
out_tokens.extend(token.tolist())
|
451 |
response = self.tokenizer.decode(out_tokens)
|
452 |
if "<|im_start|>assistant\n" in response:
|
453 |
-
response = response.split("<|im_start|>assistant\n")
|
454 |
num_tokens += 1
|
455 |
print(f"{response}", end='\r')
|
456 |
sys.stdout.flush()
|
@@ -458,7 +458,7 @@ class TransformersEngine(BaseEngine):
|
|
458 |
|
459 |
if response is not None:
|
460 |
if "<|im_start|>assistant\n" in response:
|
461 |
-
response = response.split("<|im_start|>assistant\n")
|
462 |
full_text = prompt + response
|
463 |
num_tokens = len(self.tokenizer.encode(full_text))
|
464 |
yield response, num_tokens
|
|
|
429 |
|
430 |
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
431 |
import sys
|
432 |
+
self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
433 |
with torch.no_grad():
|
434 |
inputs = self.tokenizer(prompt, return_tensors='pt')
|
435 |
num_tokens = inputs.input_ids.size(1)
|
|
|
450 |
out_tokens.extend(token.tolist())
|
451 |
response = self.tokenizer.decode(out_tokens)
|
452 |
if "<|im_start|>assistant\n" in response:
|
453 |
+
response = response.split("<|im_start|>assistant\n")[-1]
|
454 |
num_tokens += 1
|
455 |
print(f"{response}", end='\r')
|
456 |
sys.stdout.flush()
|
|
|
458 |
|
459 |
if response is not None:
|
460 |
if "<|im_start|>assistant\n" in response:
|
461 |
+
response = response.split("<|im_start|>assistant\n")[-1]
|
462 |
full_text = prompt + response
|
463 |
num_tokens = len(self.tokenizer.encode(full_text))
|
464 |
yield response, num_tokens
|