from pathlib import Path
from typing import Optional

from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.model import Transformer


def run_chat(
    model_path: str,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 1.0,
    instruct: bool = True,
    lora_path: Optional[str] = None,
):
    # Locate the v3 tokenizer file that ships with the model weights.
    model_path = Path(model_path)
    tokenizer_file = model_path / "tokenizer.model.v3"
    if not tokenizer_file.is_file():
        raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")

    mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

    # Load the model weights; a single pipeline rank keeps everything on one device.
    transformer = Transformer.from_folder(
        model_path, max_batch_size=3, num_pipeline_ranks=1
    )

    # Optionally apply a LoRA adapter on top of the base weights.
    if lora_path is not None:
        transformer.load_lora(Path(lora_path))

    if instruct:
        # Wrap the prompt in the instruct chat template so the model sees
        # a properly formatted [INST] ... [/INST] turn.
        request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
        tokens = mistral_tokenizer.encode_chat_completion(request).tokens
    else:
        # Raw completion: encode the prompt as-is, prepending only a BOS token.
        tokens = tokenizer.encode(prompt, bos=True, eos=False)

    generated_tokens, _ = generate(
        [tokens],
        transformer,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )

    answer = tokenizer.decode(generated_tokens[0])
    print(answer)


if __name__ == "__main__":
    import fire

    fire.Fire(run_chat)
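
# Example invocations via the fire-generated CLI (a sketch: the script name
# "run_chat.py" and all paths below are placeholders, not files that ship
# with mistral-inference):
#
#   python run_chat.py --model_path ~/mistral_models/7B-Instruct-v0.3 \
#       --prompt "Summarize LoRA in two sentences." --max_tokens 128
#
#   # Raw completion from a base model, with a fine-tuned LoRA adapter:
#   python run_chat.py --model_path ~/mistral_models/7B-v0.3 \
#       --prompt "The capital of France is" --instruct=False \
#       --lora_path ~/checkpoints/lora.safetensors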