import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

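# './nyc-savvy' is a local directory holding the fine-tuned checkpoint; any
# Hugging Face model id or path can be substituted here.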
model_name = './nyc-savvy'
m = AutoModelForCausalLM.from_pretrained(model_name)

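# AutoTokenizer may not resolve LLaMA-derived checkpoints (including this
# 'savvy' fine-tune) correctly, so route them to LlamaTokenizer explicitly.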
if 'llama' in model_name or 'savvy' in model_name:
    tok = LlamaTokenizer.from_pretrained(model_name)
else:
    tok = AutoTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1  # LLaMA convention: token id 1 is the BOS token <s>

stop_token_ids = [0]  # token ids that should end generation when produced

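# A StoppingCriteria is polled after each newly generated token; returning
# True halts generation, here as soon as the newest token is a stop id.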
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop = StopOnTokens()

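# Budget for newly generated tokens; the prompt length is not counted.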
max_new_tokens = 1536

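# Vicuna-style prompt: a system line, then '### Human:' / '### Assistant:'
# turns; the trailing '### Assistant: ' cues the model to answer.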
messages = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n" |
|
messages += "### Human: What museums should I visit? - My kids are aged 12 and 5" |
|
messages += "### Assistant: " |
|
|
|
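# Encode the prompt and move it onto whatever device the model lives on.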
input_ids = tok(messages, return_tensors="pt").input_ids
input_ids = input_ids.to(m.device)

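# Sampling knobs: temperature sharpens or flattens the next-token
# distribution, top_p keeps the smallest nucleus of tokens whose cumulative
# probability reaches 0.9, top_k=0 disables top-k filtering, and
# repetition_penalty > 1 discourages repeated tokens.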
temperature = 0.7
top_p = 0.9
top_k = 0
repetition_penalty = 1.1

op = m.generate(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,  # sample unless temperature is 0 (greedy)
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
)
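# generate() returns full sequences (prompt plus completion); decode each one.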
for line in op:
    print(tok.decode(line, skip_special_tokens=True))