szzzzz commited on
Commit
54d5818
·
1 Parent(s): c6a7002

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -6
app.py CHANGED
@@ -1,8 +1,60 @@
1
  import gradio as gr
2
- import chatbot
3
- print(chatbot.__version__)
4
- model = chatbot.Bot()
5
- model.load("szzzzz/chatbot_bloom_560m",low_disk_usage=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def add_text(history, text):
@@ -19,7 +71,7 @@ def bot(history):
19
  else:
20
  prompt = prompt + "\nAssistant: "
21
 
22
- response = model.generate(prompt)
23
  history[-1][1] = response
24
  return history
25
 
@@ -32,7 +84,7 @@ def regenerate(history):
32
  else:
33
  prompt = prompt + "\nAssistant: "
34
 
35
- response = model.generate(prompt)
36
  history[-1][1] = response
37
  return history
38
 
 
1
  import gradio as gr
2
+ import torch
3
+ import requests
4
+ from transformers import BloomForCausalLM, BloomTokenizerFast
5
+ import os
6
+
7
+ repo_id = 'szzzzz/chatbot_bloom_560m'
8
+ os.mkdir('./chatbot')
9
+ path = huggingface_hub.snapshot_download(
10
+ repo_id=repo_id, cache_dir='./chatbot',ignore_patterns = "*bin"
11
+ )
12
+ url = huggingface_hub.file_download.hf_hub_url(repo_id, "pytorch_model.bin")
13
+ tokenizer = BloomTokenizerFast.from_pretrained(path)
14
+ state_dict = torch.load(
15
+ io.BytesIO(requests.get(url).content), map_location=torch.device("cpu")
16
+ )
17
+ model = BloomForCausalLM.from_pretrained(
18
+ pretrained_model_name_or_path=None,
19
+ state_dict=state_dict,
20
+ config=f"{path}/config.json",
21
+ )
22
+ max_length=1024
23
+
24
+
25
+ def generate(inputs: str) -> str:
26
+ """generate content on inputs .
27
+
28
+ Args:
29
+ inputs (str):
30
+ example :'Human: 你好 .\n \nAssistant: '
31
+
32
+ Returns:
33
+ str:
34
+ bot response
35
+ example : '你好!我是你的ai助手!'
36
+
37
+ """
38
+ input_text = tokenizer.bos_token + inputs
39
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
40
+ _, input_len = input_ids.shape
41
+ if input_len >= max_length - 4:
42
+ res = "对话超过字数限制,请重新开始."
43
+ return res
44
+ pred_ids = model.generate(
45
+ input_ids,
46
+ eos_token_id=tokenizer.eos_token_id,
47
+ pad_token_id=tokenizer.pad_token_id,
48
+ bos_token_id=tokenizer.bos_token_id,
49
+ do_sample=True,
50
+ temperature=0.6,
51
+ top_p=0.8,
52
+ max_new_tokens=max_length - input_len,
53
+ repetition_penalty=1.2,
54
+ )
55
+ pred = pred_ids[0][input_len:]
56
+ res = tokenizer.decode(pred, skip_special_tokens=True)
57
+ return res
58
 
59
 
60
  def add_text(history, text):
 
71
  else:
72
  prompt = prompt + "\nAssistant: "
73
 
74
+ response = generate(prompt)
75
  history[-1][1] = response
76
  return history
77
 
 
84
  else:
85
  prompt = prompt + "\nAssistant: "
86
 
87
+ response = generate(prompt)
88
  history[-1][1] = response
89
  return history
90