pminervini committed on
Commit 5f90f73
1 Parent(s): 1e5558f
Files changed (2)
  1. app.py +49 -46
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,9 @@
 import os
 import gradio as gr
 
+import ray
+import vllm
+
 import torch
 from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, BitsAndBytesConfig
 from openai import OpenAI
@@ -56,8 +59,52 @@ def search(query, index="pubmed", num_docs=3):
 
     return docs
 
+@ray.remote(num_gpus=1, max_calls=1)
+def generate(model_name: str, messages):
+    max_new_tokens = 1024
+
+    if model_name.startswith('openai/'):
+        openai_model_name = model_name.split('/')[1]
+
+        client = OpenAI()
+        openai_res = client.chat.completions.create(model=openai_model_name,
+                                                    messages=messages,
+                                                    max_tokens=max_new_tokens,
+                                                    temperature=0)
+        print('OAI_RESPONSE', openai_res)
+        response = openai_res.choices[0].message.content.strip()
+    else:
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Load your language model from HuggingFace Transformers
+        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True)
+
+        # Define the stopping criteria using MaxTimeCriteria
+        stopping_criteria = StoppingCriteriaList([
+            # MaxTimeCriteria(32),
+            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
+        ])
+
+        # Define the generation_kwargs with stopping criteria
+        generation_kwargs = {
+            "max_new_tokens": max_new_tokens,
+            "generation_kwargs": {"stopping_criteria": stopping_criteria},
+            "return_full_text": False
+        }
+
+        # Generate response using the HF LLM
+        hf_response = generator(messages, **generation_kwargs)
+
+        print('HF_RESPONSE', hf_response)
+        response = hf_response[0]['generated_text']
+    return response
+
+@ray.remote(num_gpus=1, max_calls=1)
 def analyse(reference: str, passage: str) -> str:
-    import vllm
     fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
     prompt = [fava_input.format_map({"evidence": reference, "output": passage})]
 
@@ -105,51 +152,7 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
         }
     ]
 
-    for message in messages:
-        print('MSG', message)
-
-    max_new_tokens = 1024
-
-    if model_name.startswith('openai/'):
-        openai_model_name = model_name.split('/')[1]
-
-        client = OpenAI()
-        openai_res = client.chat.completions.create(model=openai_model_name,
-                                                    messages=messages,
-                                                    max_tokens=max_new_tokens,
-                                                    temperature=0)
-        print('OAI_RESPONSE', openai_res)
-        response = openai_res.choices[0].message.content.strip()
-    else:
-        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
-        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True, quantization_config=quantization_config)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-        # Load your language model from HuggingFace Transformers
-        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True)
-
-        # Define the stopping criteria using MaxTimeCriteria
-        stopping_criteria = StoppingCriteriaList([
-            # MaxTimeCriteria(32),
-            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
-        ])
-
-        # Define the generation_kwargs with stopping criteria
-        generation_kwargs = {
-            "max_new_tokens": max_new_tokens,
-            "generation_kwargs": {"stopping_criteria": stopping_criteria},
-            "return_full_text": False
-        }
-
-        # Generate response using the HF LLM
-        hf_response = generator(messages, **generation_kwargs)
-
-        print('HF_RESPONSE', hf_response)
-        response = hf_response[0]['generated_text']
-
-    model = tokenizer = None
+    response = generate(model_name, messages)
 
     # analysed_response = analyse(joined_docs, response)
 
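For reference, a function decorated with @ray.remote(num_gpus=1, max_calls=1), like the new generate task above, is normally dispatched with .remote() and its result retrieved with ray.get(). The minimal sketch below shows that calling pattern only; it assumes ray.init() is run once at startup (not shown in this diff) and reuses model_name and messages as built in rag_pipeline.

import ray

ray.init()  # assumption: called once at startup; not part of this commit

# .remote() schedules generate() on a worker to which Ray assigns one GPU
# (num_gpus=1) and returns an ObjectRef immediately; ray.get() blocks until
# the worker finishes and returns the response string.
response_ref = generate.remote(model_name, messages)
response = ray.get(response_ref)

Because max_calls=1 is set, Ray retires the worker process after each task, so the loaded model and its GPU memory are released between requests.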
 
requirements.txt CHANGED
@@ -5,3 +5,4 @@ transformers
 elasticsearch
 openai
 vllm
+ray