Update README.md
Browse files
README.md
CHANGED
@@ -37,21 +37,15 @@ class Conversation:
|
|
37 |
self,
|
38 |
message_template=DEFAULT_MESSAGE_TEMPLATE,
|
39 |
system_prompt=DEFAULT_SYSTEM_PROMPT,
|
40 |
-
|
41 |
):
|
42 |
self.message_template = message_template
|
43 |
-
self.
|
44 |
self.messages = [{
|
45 |
"role": "system",
|
46 |
"content": system_prompt
|
47 |
}]
|
48 |
|
49 |
-
def get_start_token_id(self):
|
50 |
-
return self.start_token_id
|
51 |
-
|
52 |
-
def get_bot_token_id(self):
|
53 |
-
return self.bot_token_id
|
54 |
-
|
55 |
def add_user_message(self, message):
|
56 |
self.messages.append({
|
57 |
"role": "user",
|
@@ -69,12 +63,12 @@ class Conversation:
|
|
69 |
for message in self.messages:
|
70 |
message_text = self.message_template.format(**message)
|
71 |
final_text += message_text
|
72 |
-
final_text +=
|
73 |
return final_text.strip()
|
74 |
|
75 |
|
76 |
def generate(model, tokenizer, prompt, generation_config):
|
77 |
-
data = tokenizer(prompt, return_tensors="pt")
|
78 |
data = {k: v.to(model.device) for k, v in data.items()}
|
79 |
output_ids = model.generate(
|
80 |
**data,
|
|
|
37 |
self,
|
38 |
message_template=DEFAULT_MESSAGE_TEMPLATE,
|
39 |
system_prompt=DEFAULT_SYSTEM_PROMPT,
|
40 |
+
response_template=DEFAULT_RESPONSE_TEMPLATE
|
41 |
):
|
42 |
self.message_template = message_template
|
43 |
+
self.response_template = response_template
|
44 |
self.messages = [{
|
45 |
"role": "system",
|
46 |
"content": system_prompt
|
47 |
}]
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def add_user_message(self, message):
|
50 |
self.messages.append({
|
51 |
"role": "user",
|
|
|
63 |
for message in self.messages:
|
64 |
message_text = self.message_template.format(**message)
|
65 |
final_text += message_text
|
66 |
+
final_text += DEFAULT_RESPONSE_TEMPLATE
|
67 |
return final_text.strip()
|
68 |
|
69 |
|
70 |
def generate(model, tokenizer, prompt, generation_config):
|
71 |
+
data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
|
72 |
data = {k: v.to(model.device) for k, v in data.items()}
|
73 |
output_ids = model.generate(
|
74 |
**data,
|