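"""Interactive CLI chat demo.

Loads a base model plus a LoRA adapter with LLaMA-Factory and streams
assistant replies in a simple read-eval-print loop. Configuration is read
from a .env file (falling back to .env.example). A LoRA adapter checkpoint
path may be passed as the first command-line argument.
"""
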
import os
import sys
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

from dotenv import find_dotenv, load_dotenv

found_dotenv = find_dotenv(".env")

# Fall back to .env.example when no .env file is present
if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"Loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)

# Add the project root to sys.path so the local llm_toolkit package can be imported
path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)

from llm_toolkit.translation_engine import *
from llm_toolkit.translation_utils import *

model_name = os.getenv("MODEL_NAME")
load_in_4bit = os.getenv("LOAD_IN_4BIT") == "true"
eval_base_model = os.getenv("EVAL_BASE_MODEL") == "true"
eval_fine_tuned = os.getenv("EVAL_FINE_TUNED") == "true"
save_fine_tuned_model = os.getenv("SAVE_FINE_TUNED") == "true"
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")

max_seq_length = 2048  # choose any; RoPE scaling is handled internally
dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+

# Log the effective configuration
print(
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)

# The LoRA adapter path can be overridden via the first CLI argument
adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)

args = dict(
    model_name_or_path=model_name,  # base model, taken from the MODEL_NAME env var
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # must match the template used in training
    finetuning_type="lora",  # must match the fine-tuning type used in training
    quantization_bit=4,  # load the model quantized to 4 bits
)
chat_model = ChatModel(args)

messages = []
print(
    "Welcome to the CLI application. Use `clear` to clear the conversation history and `exit` to quit."
)
while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        # Reset the conversation and release cached GPU memory
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)

    # Stream the assistant's reply token by token
    response = ""
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

# Release cached GPU memory before exiting
torch_gc()