BAAI
/

shunxing1234 commited on
Commit
1db77e2
·
1 Parent(s): 9b97fea

Update README_zh.md

Browse files
Files changed (1) hide show
  1. README_zh.md +21 -8
README_zh.md CHANGED
@@ -39,27 +39,40 @@ license: other
39
  ```python
40
  from transformers import AutoTokenizer, AutoModelForCausalLM
41
  import torch
42
- from cyg_conversation import covert_prompt_to_input_ids_with_history
43
 
44
- tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat-7B")
45
- model = AutoModelForCausalLM.from_pretrained("BAAI/AquilaChat-7B")
 
 
 
46
  model.eval()
47
- model.to("cuda:0")
48
- vocab = tokenizer.vocab
49
- print(len(vocab))
50
 
51
  text = "请给出10个要到北京旅游的理由。"
52
 
53
- tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=512)
54
 
55
- tokens = torch.tensor(tokens)[None,].to("cuda:0")
56
 
57
 
58
  with torch.no_grad():
59
  out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
60
 
61
  out = tokenizer.decode(out.cpu().numpy().tolist())
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
63
  print(out)
64
  ```
65
 
 
39
  ```python
40
  from transformers import AutoTokenizer, AutoModelForCausalLM
41
  import torch
 
42
 
43
+ device = torch.device("cuda:1")
44
+
45
+ model_info = "BAAI/AquilaChat-7B"
46
+ tokenizer = AutoTokenizer.from_pretrained(model_info, trust_remote_code=True)
47
+ model = AutoModelForCausalLM.from_pretrained(model_info, trust_remote_code=True)
48
  model.eval()
49
+ model.to(device)
 
 
50
 
51
  text = "请给出10个要到北京旅游的理由。"
52
 
53
+ tokens = tokenizer.encode_plus(text)['input_ids'][:-1]
54
 
55
+ tokens = torch.tensor(tokens)[None,].to(device)
56
 
57
 
58
  with torch.no_grad():
59
  out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
60
 
61
  out = tokenizer.decode(out.cpu().numpy().tolist())
62
+ if "###" in out:
63
+ special_index = out.index("###")
64
+ out = out[: special_index]
65
+
66
+ if "[UNK]" in out:
67
+ special_index = out.index("[UNK]")
68
+ out = out[:special_index]
69
+
70
+ if "</s>" in out:
71
+ special_index = out.index("</s>")
72
+ out = out[: special_index]
73
 
74
+ if len(out) > 0 and out[0] == " ":
75
+ out = out[1:]
76
  print(out)
77
  ```
78