zihanliu commited on
Commit
2347584
1 Parent(s): 115aae6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -9
README.md CHANGED
@@ -16,7 +16,7 @@ We release ChatQA1.5, which excels at RAG-based conversational question answerin
16
 
17
 
18
  ## Benchmark Results
19
- Results in ConvRAG are as follows:
20
 
21
  | | ChatQA-1.0-7B | Command-R-Plus | Llama-3-instruct-70b | GPT-4-0613 | ChatQA-1.0-70B | ChatQA-1.5-8B | ChatQA-1.5-70B |
22
  | -- |:--:|:--:|:--:|:--:|:--:|:--:|:--:|
@@ -33,7 +33,7 @@ Results in ConvRAG are as follows:
33
  | Average (all) | 47.71 | 50.93 | 52.52 | 53.90 | 54.14 | 55.17 | 58.25 |
34
  | Average (exclude HybriDial) | 46.96 | 51.40 | 52.95 | 54.35 | 53.89 | 53.99 | 57.14 |
35
 
36
- Note that ChatQA-1.5 used some samples from the HybriDial training dataset. To ensure fair comparison, we also compare average scores excluding HybriDial. The data and evaluation scripts for ConvRAG can be found here.
37
 
38
 
39
  ## Prompt Format
@@ -72,14 +72,14 @@ def get_formatted_input(messages, context):
72
  system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
73
  instruction = "Please give a full and complete answer for the question."
74
 
75
- for item in enumerate(messages):
76
  if item['role'] == "user":
77
  ## only apply this instruction for the first user turn
78
  item['content'] = instruction + " " + item['content']
79
  break
80
 
81
  conversation = ""
82
- for item in turn_list:
83
  if item["role"] == "user":
84
  conversation += "User: " + item["content"] + "\n\n"
85
  else:
@@ -90,17 +90,14 @@ def get_formatted_input(messages, context):
90
  return formatted_input
91
 
92
  formatted_input = get_formatted_input(messages, context)
93
- input_ids = tokenizer(tokenizer.bos_token + formatted_input, return_tensors="pt").to(model.device)
94
 
95
  terminators = [
96
  tokenizer.eos_token_id,
97
  tokenizer.convert_tokens_to_ids("<|eot_id|>")
98
  ]
99
 
100
- outputs = model.generate(
101
- input_ids,
102
- max_new_tokens=128,
103
- eos_token_id=terminators)
104
 
105
  response = outputs[0][input_ids.shape[-1]:]
106
  print(tokenizer.decode(response, skip_special_tokens=True))
 
16
 
17
 
18
  ## Benchmark Results
19
+ Results in ConvRAG Bench are as follows:
20
 
21
  | | ChatQA-1.0-7B | Command-R-Plus | Llama-3-instruct-70b | GPT-4-0613 | ChatQA-1.0-70B | ChatQA-1.5-8B | ChatQA-1.5-70B |
22
  | -- |:--:|:--:|:--:|:--:|:--:|:--:|:--:|
 
33
  | Average (all) | 47.71 | 50.93 | 52.52 | 53.90 | 54.14 | 55.17 | 58.25 |
34
  | Average (exclude HybriDial) | 46.96 | 51.40 | 52.95 | 54.35 | 53.89 | 53.99 | 57.14 |
35
 
36
+ Note that ChatQA-1.5 used some samples from the HybriDial training dataset. To ensure fair comparison, we also compare average scores excluding HybriDial. The data and evaluation scripts for ConvRAG can be found [here](https://huggingface.co/datasets/nvidia/ConvRAG-Bench).
37
 
38
 
39
  ## Prompt Format
 
72
  system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
73
  instruction = "Please give a full and complete answer for the question."
74
 
75
+ for item in messages:
76
  if item['role'] == "user":
77
  ## only apply this instruction for the first user turn
78
  item['content'] = instruction + " " + item['content']
79
  break
80
 
81
  conversation = ""
82
+ for item in messages:
83
  if item["role"] == "user":
84
  conversation += "User: " + item["content"] + "\n\n"
85
  else:
 
90
  return formatted_input
91
 
92
  formatted_input = get_formatted_input(messages, context)
93
+ tokenized_prompt = tokenizer(tokenizer.bos_token + formatted_input, return_tensors="pt").to(model.device)
94
 
95
  terminators = [
96
  tokenizer.eos_token_id,
97
  tokenizer.convert_tokens_to_ids("<|eot_id|>")
98
  ]
99
 
100
+ outputs = model.generate(input_ids=tokenized_prompt.input_ids, attention_mask=tokenized_prompt.attention_mask, max_new_tokens=128, eos_token_id=terminators)
 
 
 
101
 
102
  response = outputs[0][input_ids.shape[-1]:]
103
  print(tokenizer.decode(response, skip_special_tokens=True))