# Author: dh-mc
# Commit 3860729: initial code for Chinese/English translation
# File size: 2.52 kB
import os
import sys
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from dotenv import find_dotenv, load_dotenv
# Locate the env file: a real ``.env`` wins, otherwise fall back to the
# checked-in ``.env.example`` template. ``find_dotenv`` returns an empty
# string when nothing is found, which is falsy, so ``or`` picks the fallback.
found_dotenv = find_dotenv(".env") or find_dotenv(".env.example")
print("loading env vars from: " + found_dotenv)
# override=False: variables already present in the process environment are
# kept; the file only fills in the ones that are missing.
load_dotenv(found_dotenv, override=False)
# Treat the directory that holds the env file as the project root and make it
# importable, so the local ``llm_toolkit`` package can be found. When no env
# file was found this is "", which Python treats as the current directory.
path = os.path.split(found_dotenv)[0]
print("Adding {} to sys.path".format(path))
sys.path.extend([path])

# These wildcard imports only resolve after the sys.path change just above.
from llm_toolkit.translation_engine import *  # noqa: E402, F403
from llm_toolkit.translation_utils import *  # noqa: E402, F403
def _env_flag(name):
    """Return True only when environment variable *name* is exactly "true"."""
    return os.getenv(name) == "true"


# Run configuration, read from the process environment.
model_name = os.getenv("MODEL_NAME")
load_in_4bit = _env_flag("LOAD_IN_4BIT")
eval_base_model = _env_flag("EVAL_BASE_MODEL")
eval_fine_tuned = _env_flag("EVAL_FINE_TUNED")
save_fine_tuned_model = _env_flag("SAVE_FINE_TUNED")
# An unset or empty NUM_TRAIN_EPOCHS counts as 0.
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")

# Within the visible script these two are only echoed by the print below.
max_seq_length = 2048
dtype = None  # None = automatic dtype detection

# Echo the effective configuration on one space-separated line.
_startup_config = (
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)
print(*_startup_config)
# LoRA adapter to load: the first CLI argument, or a default local checkpoint.
adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)

# Inference arguments for LLaMA-Factory's ChatModel. `template` and
# `finetuning_type` must match the values used when the adapter was trained.
args = dict(
    model_name_or_path=model_name,  # base model, taken from MODEL_NAME
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # same as the one used in training
    finetuning_type="lora",  # same as the one used in training
)
# Honour LOAD_IN_4BIT (read into `load_in_4bit` earlier in this script)
# instead of always quantizing. When the flag is off, `quantization_bit` is
# omitted so the library default applies.
if load_in_4bit:
    args["quantization_bit"] = 4  # load 4-bit quantized model
chat_model = ChatModel(args)
# Conversation history as role/content dicts, passed to `stream_chat` each turn.
messages = []
print(
    "Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application."
)
while True:
    try:
        query = input("\nUser: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, Ctrl-C or the end of piped stdin: leave the loop cleanly
        # instead of crashing with a traceback.
        print()
        break

    command = query.strip()
    if command == "exit":
        break
    if command == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue
    if not command:
        # Blank line: nothing to send, so don't add an empty user turn.
        continue

    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)
    # Echo streamed chunks as they arrive and collect them for the history.
    chunks = []
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        chunks.append(new_text)
    print()
    response = "".join(chunks)
    messages.append({"role": "assistant", "content": response})
    torch_gc()