Update README.md
Browse files
README.md
CHANGED
@@ -23,16 +23,6 @@ from peft import PeftModel, PeftConfig
|
|
23 |
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
24 |
|
25 |
MODEL_NAME = "IlyaGusev/saiga_7b_lora"
|
26 |
-
|
27 |
-
config = PeftConfig.from_pretrained(MODEL_NAME)
|
28 |
-
model = AutoModelForCausalLM.from_pretrained(
|
29 |
-
config.base_model_name_or_path,
|
30 |
-
load_in_8bit=True,
|
31 |
-
device_map="auto"
|
32 |
-
)
|
33 |
-
model = PeftModel.from_pretrained(model, MODEL_NAME)
|
34 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
35 |
-
|
36 |
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
|
37 |
DEFAULT_SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
|
38 |
|
@@ -91,9 +81,14 @@ def generate(model, tokenizer, prompt, generation_config):
|
|
91 |
output = tokenizer.decode(output_ids, skip_special_tokens=True)
|
92 |
return output.strip()
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
97 |
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
98 |
print(generation_config)
|
99 |
|
|
|
23 |
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
24 |
|
25 |
MODEL_NAME = "IlyaGusev/saiga_7b_lora"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
|
27 |
DEFAULT_SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
|
28 |
|
|
|
81 |
output = tokenizer.decode(output_ids, skip_special_tokens=True)
|
82 |
return output.strip()
|
83 |
|
84 |
+
config = PeftConfig.from_pretrained(MODEL_NAME)
|
85 |
+
model = AutoModelForCausalLM.from_pretrained(
|
86 |
+
config.base_model_name_or_path,
|
87 |
+
load_in_8bit=True,
|
88 |
+
device_map="auto"
|
89 |
+
)
|
90 |
+
model = PeftModel.from_pretrained(model, MODEL_NAME)
|
91 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
92 |
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
93 |
print(generation_config)
|
94 |
|