import os
import gradio as gr
import ray
import vllm
import torch
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, BitsAndBytesConfig
from openai import OpenAI
from elasticsearch import Elasticsearch
class MultiTokenEOSCriteria(StoppingCriteria):
    """Stop generation once every row of the batch has emitted ``sequence``.

    The stop string may span several tokens, so rather than comparing token ids
    this criterion decodes a short window of the most recently generated tokens
    and searches the decoded text for the stop string.
    """

    def __init__(self, sequence: str, tokenizer: PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int = 1) -> None:
        """Configure the stop string and the look-back window.

        Args:
            sequence: Stop string to search for in the generated text.
            tokenizer: Tokenizer used both to size the look-back window and to
                decode the generated ids.
            initial_decoder_input_length: Number of leading positions in
                ``input_ids`` that belong to the prompt; they are excluded from
                the search.
            batch_size: Number of rows generated in parallel (one stop flag each).
        """
        self.sequence = sequence
        self.tokenizer = tokenizer
        self.initial_decoder_input_length = initial_decoder_input_length
        # One flag per batch row; once a row is marked finished it is never re-checked.
        self.done_tracker = [False] * batch_size
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look two tokens further back than the stop string's standalone encoding,
        # presumably because in context the string can be split or merged with
        # neighbouring text differently than when encoded alone.
        self.sequence_id_len = len(self.sequence_ids) + 2

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Update the per-row stop flags and report whether the whole batch is done.

        ``input_ids`` is indexed as ``[row, position]``; ``scores`` is unused.
        Returns True only when every row has produced the stop string.
        """
        # Discard the prompt, then keep only the trailing window that can hold the stop string.
        generated_only = input_ids[:, self.initial_decoder_input_length :]
        recent_window = generated_only[:, -self.sequence_id_len :]
        recent_text = self.tokenizer.batch_decode(recent_window)
        for row, finished in enumerate(self.done_tracker):
            if finished:
                continue
            self.done_tracker[row] = self.sequence in recent_text[row]
        return all(self.done_tracker)
def search(query, index="pubmed", num_docs=3):
    """Search an Elasticsearch index and return the contents of the best hits.

    Runs a full-text ``match`` query on the ``content`` field against the
    cluster at ``https://data.neuralnoise.com:9200``, authenticating as
    ``elastic`` with the password taken from the ``ES_PASSWORD`` environment
    variable.

    Args:
        query: Free-text query string.
        index: Name of the index to search.
        num_docs: Maximum number of documents to return. When it is <= 0 no
            request is made and an empty list is returned.

    Returns:
        A list of up to ``num_docs`` ``content`` strings, in the order in which
        Elasticsearch returned the hits.

    Raises:
        KeyError: If ``ES_PASSWORD`` is unset (only when a request is made).
    """
    if num_docs <= 0:
        return []

    print(f'Running query: {query}')
    es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
                       basic_auth=('elastic', os.environ['ES_PASSWORD']),
                       # NOTE(review): TLS certificate verification is disabled, so the
                       # connection (which carries the password) is open to MITM.
                       # Consider ca_certs=... instead of verify_certs=False.
                       verify_certs=False, ssl_show_warn=False)
    try:
        response = es.options(request_timeout=60).search(
            index=index,
            # Assumes documents store their text in a 'content' field.
            query={"match": {"content": query}},
            size=num_docs,
        )
    finally:
        # A fresh client is created per call; close it so its connections are released.
        es.close()

    docs = [hit["_source"]["content"] for hit in response['hits']['hits']]
    print(f'Received {len(docs)} documents from index {index}')
    return docs
@ray.remote(num_gpus=1, max_calls=1)
def generate(model_name: str, messages):
    """Generate a chat reply for ``messages`` with the requested model.

    Runs as a Ray task that reserves one GPU. ``max_calls=1`` makes Ray retire
    the worker after a single call, so the model loaded on the local path does
    not stay resident between requests.

    Args:
        model_name: ``"openai/<model>"`` routes the request to the OpenAI Chat
            Completions API; any other value is treated as a Hugging Face model
            id and loaded locally in 4-bit (bitsandbytes) quantization.
        messages: Chat messages passed unchanged to the OpenAI client and to the
            tokenizer's chat template (presumably a list of
            ``{"role": ..., "content": ...}`` dicts — confirm against callers).

    Returns:
        The generated text. On the OpenAI path it is stripped of surrounding
        whitespace; on the local path generation stops at the first newline
        produced after the prompt.
    """
    max_new_tokens = 1024
    if model_name.startswith('openai/'):
        # Keep everything after the first "/" so ids that themselves contain "/" survive.
        openai_model_name = model_name.split('/', 1)[1]
        client = OpenAI()
        openai_res = client.chat.completions.create(model=openai_model_name,
                                                    messages=messages,
                                                    max_tokens=max_new_tokens,
                                                    temperature=0)
        print('OAI_RESPONSE', openai_res)
        # The SDK types message.content as optional; treat a missing value as empty text.
        response = (openai_res.choices[0].message.content or '').strip()
    else:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", low_cpu_mem_usage=True,
                                                     quantization_config=quantization_config)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        # The text-generation pipeline renders chat input with add_generation_prompt=True
        # (for conversations ending in a user turn), so count the prompt the same way.
        # Otherwise the assistant-header tokens, which may contain newlines, would fall
        # inside the criterion's search window and could stop generation immediately.
        tokenized_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
        # Stop as soon as a newline appears in the newly generated text.
        stopping_criteria = StoppingCriteriaList([
            MultiTokenEOSCriteria("\n", tokenizer, len(tokenized_prompt))
        ])
        # Extra keyword arguments to the pipeline call are forwarded to model.generate(),
        # so stopping_criteria must be passed at top level; a nested "generation_kwargs"
        # dict is not an option generate() recognises.
        hf_response = generator(messages,
                                max_new_tokens=max_new_tokens,
                                stopping_criteria=stopping_criteria,
                                return_full_text=False)
        print('HF_RESPONSE', hf_response)
        # With return_full_text=False only the newly generated text is returned.
        response = hf_response[0]['generated_text']
    return response
@ray.remote(num_gpus=1, max_calls=1)
def analyse(reference: str, passage: str) -> str:
fava_input = "Read the following references:\n{evidence}\nPlease identify all the errors in the following text using the information in the references provided and suggest edits if necessary:\n[Text] {output}\n[Edited] "
prompt = [fava_input.format_map({"evidence": reference, "output": passage})]
model = vllm.LLM(model="fava-uw/fava-model")
sampling_params = vllm.SamplingParams(temperature=0, top_p=1.0, max_tokens=500)
outputs = model.generate(prompt, sampling_params)
outputs = [it.outputs[0].text for it in outputs]
output = outputs[0].replace("", " ")
output = output.replace("", " ")
output = output.replace("