import os
import time
from datetime import datetime
import logging
from pathlib import Path
import requests
import json

import numpy as np
import pandas as pd
import spacy
from sentence_transformers import CrossEncoder
import litellm
from tqdm import tqdm
from transformers import (AutoTokenizer, AutoModelForCausalLM, pipeline,
                          AutoModelForTokenClassification)
import torch
import cohere
from openai import OpenAI
import anthropic
import replicate
# pip install -U google-generativeai
import google.generativeai as genai
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

import src.backend.util as util
import src.envs as envs

litellm.set_verbose = True

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Load spacy model for word tokenization
nlp = spacy.load("en_core_web_sm")

os.environ["HUGGINGFACE_API_KEY"] = envs.TOKEN


class ModelLoadingException(Exception):
    """Exception raised for errors in loading a model.

    Attributes:
        model_id (str): The model identifier.
        revision (str): The model revision.
    """
    def __init__(self, model_id, revision, messages="Error initializing model"):
        self.model_id = model_id
        self.revision = revision
        super().__init__(f"{messages} id={model_id} revision={revision}")


class SummaryGenerator:
    """A class to generate summaries using a causal language model.

    Attributes:
        model (str): huggingface/{model_id}
        api_base (str): https://api-inference.huggingface.co/models/{model_id}
        summaries_df (DataFrame): DataFrame to store generated summaries.
        revision (str): Model revision.
        avg_length (float): Average length of summaries.
        answer_rate (float): Rate of non-empty summaries.
    """

    def __init__(self, model_id, revision, device):
        """
        Initializes the SummaryGenerator with a model.

        Args:
            model_id (str): Identifier for the model.
            revision (str): Revision of the model.
            device (str): Device on which local models are run.
        """
        self.model_id = model_id
        self.model = f"huggingface/{model_id}"
        self.api_base = f"https://api-inference.huggingface.co/models/{model_id}"
        self.summaries_df = pd.DataFrame()
        self.revision = revision
        self.device = device
        self.avg_length = None
        self.answer_rate = None
        self.exceptions = None
        self.local_model = None
        self.local_pipeline = None

    def generate_summaries(self, df, save_path=None):
        """Generate summaries for a given DataFrame of source docs.

        Args:
            df (DataFrame): DataFrame containing source docs.
            save_path (str, optional): If given, summaries are loaded from or
                saved to this CSV path.

        Returns:
            summaries_df (DataFrame): Generated summaries by the model.
        """
        exceptions = []
        if (save_path is not None) and os.path.exists(save_path):
            self.summaries_df = pd.read_csv(save_path)
            print(f'Loaded generated summaries from {save_path}')
        else:
            source, summary, dataset = [], [], []
            print(f"Total: {df.shape[0]}")
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                _source = row['text']
                _dataset = row['dataset']

                system_prompt = envs.SYSTEM_PROMPT
                user_prompt = f"{envs.USER_PROMPT}\nPassage:\n{_source}"

                _summary = None
                while not _summary:
                    try:
                        _summary = self.generate_summary(system_prompt, user_prompt)
                        # print(f"Finish index {index}")
                        break
                    except Exception as e:
                        if 'Rate limit reached' in str(e):
                            wait_time = 300
                            current_time = datetime.now().strftime('%H:%M:%S')
                            print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
                            time.sleep(wait_time)
                        elif 'is currently loading' in str(e):
                            wait_time = 200
                            print(f"Model is loading, waiting {wait_time}s")
                            time.sleep(wait_time)
                        elif '429' in str(e):  # for Gemini models
                            wait_time = 60
                            print(f"Quota reached, waiting {wait_time}s")
                            time.sleep(wait_time)
                        else:
                            print(f"Error at index {index}: {e}")
                            _summary = ""
                            exceptions.append(index)
                            break

                summary.append(_summary)
                source.append(_source)
                dataset.append(_dataset)

                # Sleep to prevent hitting rate limits too frequently
                time.sleep(1)

            self.summaries_df = pd.DataFrame(list(zip(source, summary, dataset)),
                                             columns=["source", "summary", "dataset"])

            if save_path is not None:
                print(f'Save summaries to {save_path}')
                fpath = Path(save_path)
                fpath.parent.mkdir(parents=True, exist_ok=True)
                self.summaries_df.to_csv(fpath)

        self.exceptions = exceptions
        self._compute_avg_length()
        self._compute_answer_rate()

        return self.summaries_df

    def generate_summary(self, system_prompt: str, user_prompt: str):
        # Decide which backend serves this model: Together AI, Replicate,
        # a Transformers pipeline, or one of the dedicated provider APIs below.
        using_together_api = False
        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen', 'zero-one-ai']  # , 'mistralai'
        using_replicate_api = False
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        using_pipeline = False
        pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b']

        for replicate_api_model in replicate_api_models:
            if replicate_api_model in self.model_id.lower():
                using_replicate_api = True
                break

        if not using_replicate_api:
            for together_ai_api_model in together_ai_api_models:
                if together_ai_api_model in self.model_id.lower():
                    using_together_api = True
                    break

        if not using_replicate_api and not using_together_api:
            for pipeline_model in pipeline_models:
                if pipeline_model in self.model_id.lower():
                    using_pipeline = True
                    break

        # For Mixtral, DBRX, WizardLM, etc., use the Together AI API
        if using_together_api:
            # suffix = "completions" if ('mixtral' in self.model_id.lower() or 'base' in self.model_id.lower()) else "chat/completions"
            suffix = "chat/completions"
            url = f"https://api.together.xyz/v1/{suffix}"

            payload = {
                "model": self.model_id,
                "max_new_tokens": 250,
                "temperature": 0.0,
            }
            payload['messages'] = [{"role": "system", "content": system_prompt},
                                   {"role": "user", "content": user_prompt}]
            headers = {
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}"
            }

            response = requests.post(url, json=payload, headers=headers)
            print(response)
            try:
                result = json.loads(response.text)
                result = result["choices"][0]

                if 'message' in result:
                    result = result["message"]["content"].strip()
                else:
                    result = result["text"]
                    result_candidates = [candidate for candidate in result.split('\n\n') if len(candidate) > 0]
                    result = result_candidates[0]
            except Exception:
                result = ''
            print(result)
            return result

        # Using OpenAI API
        elif 'gpt' in self.model_id.lower():
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model_id.replace('openai/', ''),
                messages=[{"role": "system", "content": system_prompt},
                          {"role": "user", "content": user_prompt}],
                temperature=0.0,
                max_tokens=250,
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using Google AI API for Gemini models
        elif 'gemini' in self.model_id.lower():
            genai.configure(api_key=os.getenv('GOOGLE_AI_API_KEY'))
            generation_config = {
                "temperature": 0,
                "top_p": 0.95,  # cannot change
                "top_k": 0,
                "max_output_tokens": 250,
                # "response_mime_type": "application/json",
            }
            safety_settings = [
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            ]
            model = genai.GenerativeModel(
                model_name=self.model_id.lower().split('google/')[-1],
                generation_config=generation_config,
                system_instruction=system_prompt,
                safety_settings=safety_settings,
            )
            convo = model.start_chat(history=[])
            convo.send_message(user_prompt)
            result = convo.last.text
            print(result)
            return result

        # Using Replicate API
        elif using_replicate_api:
            print("using replicate")
            if 'snowflake' in self.model_id.lower():
                input = {
                    "prompt": user_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                    "stop_sequences": "<|im_end|>",
                    "prompt_template": f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
                                       + "<|im_start|>user\n{prompt}<|im_end|>\n\n<|im_start|>assistant\n",
                }
            else:
                input = {
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                }
            response = replicate.run(self.model_id, input=input)
            # Replicate may stream the output as a list of chunks
            if isinstance(response, list):
                response = ''.join(response)
            print(response)
            return response

        # Using Anthropic API
        elif 'claude' in self.model_id.lower():
            client = anthropic.Anthropic()
            message = client.messages.create(
                model=self.model_id.split('/')[-1],
                max_tokens=250,
                temperature=0,
                system=system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": user_prompt
                            }
                        ]
                    }
                ]
            )
            result = message.content[0].text
            print(result)
            return result

        # Using Mistral API
        elif 'mistral-large' in self.model_id.lower():
            api_key = os.environ["MISTRAL_API_KEY"]
            client = MistralClient(api_key=api_key)

            messages = [
                ChatMessage(role="system", content=system_prompt),
                ChatMessage(role="user", content=user_prompt)
            ]

            # No streaming
            chat_response = client.chat(
                model=self.model_id,
                messages=messages,
            )
            result = chat_response.choices[0].message.content
            print(result)
            return result

        # Using HF API or download checkpoints
        elif self.local_model is None and self.local_pipeline is None:
            # try:  # try using the HuggingFace Inference API first
            #     print('** using huggingface api')
            #     response = litellm.completion(
            #         model=self.model,
            #         messages=[{"role": "system", "content": system_prompt},
            #                   {"role": "user", "content": user_prompt}],
            #         temperature=0.0,
            #         max_tokens=250,
            #         api_base=self.api_base,
            #     )
            #     result = response['choices'][0]['message']['content']
            #     result = result.split('<|im_end|>')[0]
            #     print(result)
            #     return result
            # except Exception as e:
            #     if 'Rate limit reached' in str(e):
            #         wait_time = 300
            #         current_time = datetime.now().strftime('%H:%M:%S')
            #         print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
            #         time.sleep(wait_time)
            #     else:
            if using_pipeline:
                self.local_pipeline = pipeline(
                    "text-generation",
                    model=self.model_id,
                    model_kwargs={"torch_dtype": torch.bfloat16},
                    device_map="auto",
                )
            else:
                # OpenELM models reuse the Llama-2 tokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(
                    "meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id,
                    trust_remote_code=True)
                print("Tokenizer loaded")
                self.local_model = AutoModelForCausalLM.from_pretrained(
                    self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
                print(self.local_model.device)
                print("Local model loaded")

        # Using local model/pipeline
        if self.local_pipeline:
            print('Using Transformers pipeline')
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=250,
            )
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
            return result
        elif self.local_model:  # cannot call an API; using the local model
            print('Using local model')
            if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                # gemma-1.1 and mistral-7b do not accept a system role
                messages = [
                    {"role": "user", "content": system_prompt + ' ' + user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            elif 'phi-2' in self.model_id.lower():
                prompt = system_prompt + '\n' + user_prompt
            elif 'intel' in self.model_id.lower():
                prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"
            else:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

            input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.local_model.generate(
                    **input_ids,
                    max_new_tokens=250,
                    do_sample=True,
                    temperature=0.01,
                    pad_token_id=self.tokenizer.eos_token_id)
            if 'glm' in self.model_id.lower():
                outputs = outputs[:, input_ids['input_ids'].shape[1]:]
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the prompt from the decoded output
            if 'gemma-2' in self.model_id.lower():
                result = result.split(user_prompt + '\nmodel')[-1].strip()
            elif 'intel' in self.model_id.lower():
                result = result.split("### Assistant:\n")[-1]
            else:
                print(prompt)
                print('-' * 50)
                result = result.replace(prompt.strip(), '')

            print(result)
            return result

    def _compute_avg_length(self):
        """
        Compute the average length of non-empty summaries using SpaCy.
        """
        total_word_count = 0
        total_count = 0

        for summary in self.summaries_df['summary']:
            if util.is_summary_valid(summary):
                doc = nlp(summary)
                words = [token.text for token in doc if token.is_alpha]
                total_word_count += len(words)
                total_count += 1

        self.avg_length = 0 if total_count == 0 else total_word_count / total_count

    def _compute_answer_rate(self):
        """
        Compute the rate of non-empty summaries.
        """
        valid_count = sum(1 for summary in self.summaries_df['summary']
                          if util.is_summary_valid(summary))
        total_count = len(self.summaries_df)
        self.answer_rate = 0 if total_count == 0 else valid_count / total_count


class EvaluationModel:
    """A class to evaluate generated summaries.

    Attributes:
        model: The HHEM evaluation model (loaded as a token-classification head).
        scores (list): List of evaluation scores.
        factual_consistency_rate (float): Percentage of factually consistent summaries.
        hallucination_rate (float): Rate of hallucination in summaries.
""" def __init__(self, model_path, device): """ Initializes the EvaluationModel with a CrossEncoder model. Args: model_path (str): Path to the CrossEncoder model. """ self.model = AutoModelForTokenClassification.from_pretrained(model_path) self.device = device self.model.to(self.device) self.scores = [] self.factual_consistency_rate = None self.hallucination_rate = None def predict(self, text_pairs): """Load LoRA adapters of HHEM and make predictions All HHEM 2.1 settings, e.g., prompt template, are hardcoded in this function. Args: text_pairs: list of tuples, each tuple contains two strings (premise, hypothesis) checkpoint: model ID on Hugging Face """ prompt = " Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}" tokenizer = AutoTokenizer.from_pretrained('t5-base') inputs = tokenizer( [prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs], return_tensors='pt', padding='longest').to(self.device) self.model.eval() with torch.no_grad(): output = self.model(**inputs) logits = output.logits logits = logits[:,0,:] # get the logits on the first token logits = torch.softmax(logits, dim=-1) scores = [round(x, 5) for x in logits[:, 1].tolist()] # list of float return scores def evaluate_hallucination(self, summaries_df): """ Evaluate the hallucination rate in summaries. Updates the 'scores' attribute of the instance with the computed scores. Args: summaries_df (DataFrame): DataFrame containing source docs and summaries. Returns: list: List of hallucination scores. Also updates the 'scores' attribute of the instance. """ hem_scores = [] sources = [] summaries = [] source_summary_pairs = util.create_pairs(summaries_df) for doc, summary in source_summary_pairs: if util.is_summary_valid(summary): try: summary = summary.replace('','').replace('','').strip() score = self.predict([(doc, summary)])[0] # print(score) # if score < 0.5: # print(doc) # print('-'*10) # print(summary) # print('='*20) hem_scores.append(score) sources.append(doc) summaries.append(summary) except Exception as e: logging.error(f"Error while running HEM: {e}") raise self.scores = hem_scores eval_results = {'source': sources, 'summary': summaries, 'HEM scores': hem_scores} return hem_scores, eval_results def compute_factual_consistency_rate(self, threshold=0.5): """ Compute the factual consistency rate of the evaluated summaries based on the previously calculated scores. This method relies on the 'scores' attribute being populated, typically via the 'evaluate_hallucination' method. Returns: float: Factual Consistency Rate. Also updates the 'factual_consistency_rate' and 'hallucination_rate' attributes of the instance. Raises: ValueError: If scores have not been calculated prior to calling this method. """ if not self.scores: error_msg = "Scores not calculated. Call evaluate_hallucination() first." logging.error(error_msg) raise ValueError(error_msg) # Use threshold of 0.5 to compute factual_consistency_rate num_above_threshold = sum(score >= threshold for score in self.scores) num_total = len(self.scores) if not num_total: raise ValueError("No scores available to compute factual consistency rate.") self.factual_consistency_rate = (num_above_threshold / num_total) * 100 self.hallucination_rate = 100 - self.factual_consistency_rate return self.factual_consistency_rate