import os
import sys
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from dotenv import find_dotenv, load_dotenv
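# Load environment variables from .env (falling back to .env.example) and
# add the project root to sys.path so local packages can be imported.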
found_dotenv = find_dotenv(".env")
if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)
path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)
from llm_toolkit.translation_engine import *
from llm_toolkit.translation_utils import *
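# Read run configuration from environment variables.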
model_name = os.getenv("MODEL_NAME")
load_in_4bit = os.getenv("LOAD_IN_4BIT") == "true"
eval_base_model = os.getenv("EVAL_BASE_MODEL") == "true"
eval_fine_tuned = os.getenv("EVAL_FINE_TUNED") == "true"
save_fine_tuned_model = os.getenv("SAVE_FINE_TUNED") == "true"
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")
max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = (
    None  # None for auto detection. Float16 for Tesla T4, V100. Bfloat16 for Ampere+
)
print(
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)
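# The LoRA adapter path can be overridden via the first command-line argument.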
adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)
args = dict(
    model_name_or_path=model_name,  # base model to load (path or Hugging Face model ID)
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # same template as the one used in training
    finetuning_type="lora",  # same finetuning type as the one used in training
    quantization_bit=4,  # load the 4-bit quantized model
)
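# Build the chat model from the base model plus the LoRA adapter.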
chat_model = ChatModel(args)
messages = []
print(
    "Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application."
)
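# Interactive chat loop: stream the assistant's reply token by token.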
while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue
    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)
    response = ""
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})
torch_gc()