from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from transformers import AutoModelForCausalLM
import torch
# Load the Mistral tokenizer
model_name = "nemostral"
tokenizer = MistralTokenizer.from_model(model_name)
# Tokenize a list of messages
tokenized = tokenizer.encode_chat_completion(
    ChatCompletionRequest(
        messages=[
            UserMessage(content="How many people live in France and all its neighbours? List all of them!")
        ],
        model=model_name,
    )
)
tokens, text = tokenized.tokens, tokenized.text
input_ids = torch.tensor([tokens]).to("cuda")
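# `tokens` holds the prompt token ids; `text` is the rendered prompt string
# (roughly "<s>[INST] ... [/INST]" for instruct tokenizers), handy for debugging.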
# Load the model weights from the current directory in bfloat16
model = AutoModelForCausalLM.from_pretrained("./", torch_dtype=torch.bfloat16).to("cuda")
out = model.generate(input_ids, max_new_tokens=1024)

# Keep only the newly generated tokens, dropping the prompt and the final token (EOS)
generated = out[0, input_ids.shape[-1]:-1].tolist()
print(tokenizer.decode(generated))
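# Note: `generate` uses greedy decoding here unless the model's generation config
# says otherwise. A hypothetical variant with sampling (the parameter values are
# an assumption, not part of this snippet):
# out = model.generate(input_ids, max_new_tokens=1024, do_sample=True, temperature=0.3)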