safufu commited on
Commit
aaa92a8
1 Parent(s): c227e3c

Update README.md

Browse files

using pipeline to be more friendly

Files changed (1) hide show
  1. README.md +10 -15
README.md CHANGED
@@ -26,19 +26,21 @@ license_link: LICENSE
26
  可通过以下代码加载 Index-1.9B-Chat 模型来进行对话:
27
 
28
  ```python
29
- import torch
30
  import argparse
31
- from transformers import AutoTokenizer, AutoModelForCausalLM
32
 
33
  # 注意!目录不能含有".",可以替换成"_"
34
  parser = argparse.ArgumentParser()
35
- parser.add_argument('--model_path', default="", type=str, help="")
 
36
  args = parser.parse_args()
37
 
38
  tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
39
- model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map='auto')
40
- model = model.eval()
41
- print('model loaded', args.model_path, model.device)
 
 
42
 
43
  system_message = "你是由哔哩哔哩自主研发的大语言模型,名为“Index”。你能够根据用户传入的信息,帮助用户完成指定的任务,并生成恰当的、符合要求的回复。"
44
  query = "续写 天不生我金坷垃"
@@ -46,15 +48,8 @@ model_input = []
46
  model_input.append({"role": "system", "content": system_message})
47
  model_input.append({"role": "user", "content": query})
48
 
49
- inputs = tokenizer.apply_chat_template(model_input, tokenize=False, add_generation_prompt=False)
50
- input_ids = tokenizer.encode(inputs, return_tensors="pt").to(model.device)
51
- history_outputs = model.generate(input_ids, max_new_tokens=300, top_k=5, top_p=0.8, temperature=0.3, repetition_penalty=1.1, do_sample=True)
52
-
53
- # 删除</s>
54
- if history_outputs[0][-1] == 2:
55
- history_outputs = history_outputs[:, :-1]
56
 
57
- outputs = history_outputs[0][len(input_ids[0]):]
58
  print('User:', query)
59
- print('\nModel:', tokenizer.decode(outputs))
60
  ```
 
26
  可通过以下代码加载 Index-1.9B-Chat 模型来进行对话:
27
 
28
  ```python
 
29
  import argparse
30
+ from transformers import AutoTokenizer, pipeline
31
 
32
  # 注意!目录不能含有".",可以替换成"_"
33
  parser = argparse.ArgumentParser()
34
+ parser.add_argument('--model_path', default="IndexTeam/Index-1.9B-Chat", type=str, help="")
35
+ parser.add_argument('--device', default="cpu", type=str, help="") # also could be "cuda" or "mps" for Apple silicon
36
  args = parser.parse_args()
37
 
38
  tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
39
+ generator = pipeline("text-generation",
40
+ model=args.model_path,
41
+ tokenizer=tokenizer, trust_remote_code=True,
42
+ device=args.device)
43
+
44
 
45
  system_message = "你是由哔哩哔哩自主研发的大语言模型,名为“Index”。你能够根据用户传入的信息,帮助用户完成指定的任务,并生成恰当的、符合要求的回复。"
46
  query = "续写 天不生我金坷垃"
 
48
  model_input.append({"role": "system", "content": system_message})
49
  model_input.append({"role": "user", "content": query})
50
 
51
+ model_output = generator(model_input, max_new_tokens=300, top_k=5, top_p=0.8, temperature=0.3, repetition_penalty=1.1, do_sample=True)
 
 
 
 
 
 
52
 
 
53
  print('User:', query)
54
+ print('Model:', model_output)
55
  ```