import os
import sys

from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from dotenv import find_dotenv, load_dotenv

# Locate and load environment variables, falling back to the example file.
found_dotenv = find_dotenv(".env")

if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)

# Make the project root importable before pulling in the local toolkit modules.
path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)

from llm_toolkit.translation_engine import *
from llm_toolkit.translation_utils import *

model_name = os.getenv("MODEL_NAME")
load_in_4bit = os.getenv("LOAD_IN_4BIT") == "true"
eval_base_model = os.getenv("EVAL_BASE_MODEL") == "true"
eval_fine_tuned = os.getenv("EVAL_FINE_TUNED") == "true"
save_fine_tuned_model = os.getenv("SAVE_FINE_TUNED") == "true"
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")

max_seq_length = 2048  # choose any value; RoPE scaling is supported internally
dtype = (
    None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
)

print(
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)

# LoRA adapter checkpoint can be overridden via the first CLI argument.
adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)

args = dict(
    model_name_or_path=model_name,  # base model, taken from the MODEL_NAME env var
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # same as the template used in training
    finetuning_type="lora",  # same as the finetuning type used in training
    quantization_bit=4,  # load the model in 4-bit quantization
)
chat_model = ChatModel(args)

messages = []
print(
    "Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application."
)

while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break

    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)

    # Stream the response token by token and keep it in the chat history.
    response = ""
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

torch_gc()