import os
import sys

from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from dotenv import find_dotenv, load_dotenv

found_dotenv = find_dotenv(".env")
if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)

path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)

from llm_toolkit.translation_engine import *
from llm_toolkit.translation_utils import *

model_name = os.getenv("MODEL_NAME")
load_in_4bit = os.getenv("LOAD_IN_4BIT") == "true"
eval_base_model = os.getenv("EVAL_BASE_MODEL") == "true"
eval_fine_tuned = os.getenv("EVAL_FINE_TUNED") == "true"
save_fine_tuned_model = os.getenv("SAVE_FINE_TUNED") == "true"
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")
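
# The settings above are read from .env (or .env.example as a fallback).
# A minimal sketch of the expected entries -- the values shown are
# illustrative assumptions, not the project's actual defaults:
#
#   MODEL_NAME=Qwen/Qwen2-0.5B-Instruct
#   LOAD_IN_4BIT=true
#   EVAL_BASE_MODEL=true
#   EVAL_FINE_TUNED=true
#   SAVE_FINE_TUNED=true
#   NUM_TRAIN_EPOCHS=10
#   DATA_PATH=data/dataset.tsv
#   RESULTS_PATH=results/results.csv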

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+

print(
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)

adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)
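
# Usage sketch: an optional first CLI argument overrides the default adapter
# checkpoint above. The script filename below is a placeholder assumption:
#
#   python chat_cli.py llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560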

args = dict(
    model_name_or_path=model_name,  # base model taken from the MODEL_NAME env var
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # same as the template used during training
    finetuning_type="lora",  # same as the finetuning type used during training
    quantization_bit=4,  # load the model quantized to 4 bits
)
chat_model = ChatModel(args)

messages = []
print(
    "Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application."
)

while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)

    response = ""
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

torch_gc()