amazingvince commited on
Commit
dec7817
1 Parent(s): 7e2d4ba

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +110 -0
README.md CHANGED
@@ -1,3 +1,113 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ Included is code ripped from fastchat with the expected chat templating.
6
+
7
+ ```python
8
+ import dataclasses
9
+ from enum import auto, Enum
10
+ from typing import List, Tuple, Any
11
+
12
+
13
+ class SeparatorStyle(Enum):
14
+ """Different separator style."""
15
+ SINGLE = auto()
16
+ TWO = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that keeps all conversation history."""
22
+ system: str
23
+ roles: List[str]
24
+ messages: List[List[str]]
25
+ offset: int
26
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27
+ sep: str = "###"
28
+ sep2: str = None
29
+
30
+ # Used for gradio server
31
+ skip_next: bool = False
32
+ conv_id: Any = None
33
+
34
+ def get_prompt(self):
35
+ if self.sep_style == SeparatorStyle.SINGLE:
36
+ ret = self.system
37
+ for role, message in self.messages:
38
+ if message:
39
+ ret += self.sep + " " + role + ": " + message
40
+ else:
41
+ ret += self.sep + " " + role + ":"
42
+ return ret
43
+ elif self.sep_style == SeparatorStyle.TWO:
44
+ seps = [self.sep, self.sep2]
45
+ ret = self.system + seps[0]
46
+ for i, (role, message) in enumerate(self.messages):
47
+ if message:
48
+ ret += role + ": " + message + seps[i % 2]
49
+ else:
50
+ ret += role + ":"
51
+ return ret
52
+ else:
53
+ raise ValueError(f"Invalid style: {self.sep_style}")
54
+
55
+ def append_message(self, role, message):
56
+ self.messages.append([role, message])
57
+
58
+ def to_gradio_chatbot(self):
59
+ ret = []
60
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
61
+ if i % 2 == 0:
62
+ ret.append([msg, None])
63
+ else:
64
+ ret[-1][-1] = msg
65
+ return ret
66
+
67
+ def copy(self):
68
+ return Conversation(
69
+ system=self.system,
70
+ roles=self.roles,
71
+ messages=[[x, y] for x, y in self.messages],
72
+ offset=self.offset,
73
+ sep_style=self.sep_style,
74
+ sep=self.sep,
75
+ sep2=self.sep2,
76
+ conv_id=self.conv_id)
77
+
78
+ def dict(self):
79
+ return {
80
+ "system": self.system,
81
+ "roles": self.roles,
82
+ "messages": self.messages,
83
+ "offset": self.offset,
84
+ "sep": self.sep,
85
+ "sep2": self.sep2,
86
+ "conv_id": self.conv_id,
87
+ }
88
+
89
+
90
+
91
+ conv = Conversation(
92
+ system="A chat between a curious user and an artificial intelligence assistant. "
93
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
94
+ roles=("USER", "ASSISTANT"),
95
+ messages=[],
96
+ offset=0,
97
+ sep_style=SeparatorStyle.TWO,
98
+ sep=" ",
99
+ sep2="</s>",
100
+ )
101
+
102
+ conv.append_message(conv.roles[0], "Why would Microsoft take this down?")
103
+ conv.append_message(conv.roles[1], None)
104
+ prompt = conv.get_prompt()
105
+
106
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
107
+
108
+ result = model.generate(**inputs, max_new_tokens=1000)
109
+ generated_ids = result[0]
110
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
111
+ print(generated_text)
112
+
113
+ ```