nroggendorff committed
Commit b2dcefa · verified · Parent: 48307dc

Update app.py

Files changed (1)
  1. app.py  +8 -29
app.py CHANGED
@@ -5,37 +5,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 torch.set_default_device("cuda")
 
-tokenizer = AutoTokenizer.from_pretrained(
-    "cognitivecomputations/dolphin-2.9.1-mixtral-1x22b",
-    trust_remote_code=True
-)
-
-model = AutoModelForCausalLM.from_pretrained(
-    "cognitivecomputations/dolphin-2.9.1-mixtral-1x22b",
-    torch_dtype="auto",
-    load_in_4bit=True,
-    trust_remote_code=True
-)
-
-system_prompt = "<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"
+pipe = pipeline("text-generation", model="cognitivecomputations/dolphin-2.9.1-mixtral-1x22b")
 
 @spaces.GPU(duration=120)
 def predict(message, history):
-    history_transformer_format = history + [[message, ""]]
-    messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
-    input_ids = tokenizer([messages], return_tensors="pt").input_ids
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = {
-        'input_ids': input_ids,
-        'streamer': streamer,
-        'max_new_tokens': 10000,
-        'do_sample': True,
-        'top_p': 0.95,
-        'top_k': 50,
-        'temperature': 0.7,
-        'num_beams': 1
-    }
-    output = model.generate(**generate_kwargs)
-    partial_message = streamer.decode(output[0], skip_special_tokens=True)
+    conv = [{"role": "system", "content": "You are Dolphin, a helpful AI assistant."}]
+    for item in history:
+        conv.append({"role": "user", "content": item[0]})
+        conv.append({"role": "assistant", "content": item[1]})
+    conv.append({"role": "user", "content": message})
+    generated_text = pipe(conv, max_new_tokens=1024)[0]['generated_text'][-1]['content']
+    return generated_text
 
 gr.ChatInterface(predict).launch()
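
For reference, below is a minimal sketch of how app.py plausibly reads after this commit. Only the hunk starting at line 5 is shown in the diff, so the import block is an assumption: pipeline must come from transformers for the new code to run, and the gradio/spaces/torch imports are inferred from the identifiers the hunk uses. The sketch also assumes Gradio's tuple-style chat history (pairs of user/assistant strings), which is what the loop indexes into.

# Sketch of app.py after this commit; imports are assumed, not shown in the diff.
import gradio as gr
import spaces
import torch
from transformers import pipeline

torch.set_default_device("cuda")

# Build the pipeline once at startup; it is reused for every request.
pipe = pipeline("text-generation", model="cognitivecomputations/dolphin-2.9.1-mixtral-1x22b")

@spaces.GPU(duration=120)
def predict(message, history):
    # Rebuild the conversation in the chat-messages format the pipeline
    # accepts: a list of {"role": ..., "content": ...} dicts.
    conv = [{"role": "system", "content": "You are Dolphin, a helpful AI assistant."}]
    for item in history:
        conv.append({"role": "user", "content": item[0]})
        conv.append({"role": "assistant", "content": item[1]})
    conv.append({"role": "user", "content": message})
    # In chat mode the pipeline returns the whole conversation under
    # 'generated_text'; the last message is the new assistant reply.
    generated_text = pipe(conv, max_new_tokens=1024)[0]['generated_text'][-1]['content']
    return generated_text

gr.ChatInterface(predict).launch()

The trade-off in the change itself: swapping the hand-built <|im_start|> prompt string, TextIteratorStreamer, and manual model.generate() call for the pipeline's chat-messages format drops the prompt-template code and fixes the old predict(), which never returned its result, but it gives up token streaming: the reply now arrives only after generation completes.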