import gradio as gr
import pickle
import numpy as np
import glob
import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, set_seed
from peft import PeftModel
import logging
import os
import json
import spaces
import ir_datasets
import pytrec_eval
from huggingface_hub import login
import transformers
import peft
import faiss
import sys
from collections import defaultdict
set_seed(42)
# Set up logging with timestamps
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
# Authenticate with HF_TOKEN
login(token=os.environ['HF_TOKEN'])
# Global variables
CUR_MODEL = "Samaya-AI/Promptriever-Llama2-v1"
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
retrievers = {}
corpus_lookups = {}
queries = {}
q_lookups = {}
qrels = {}
query2qid = {}
datasets = ["scifact"]
current_dataset = "scifact"
faiss_index = None
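
# Log environment details (Python version, package versions, CUDA devices) at startup for debugging.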
def log_system_info():
    logger.info("System Information:")
    logger.info(f"Python version: {sys.version}")
    logger.info("\nPackage Versions:")
    logger.info(f"torch: {torch.__version__}")
    logger.info(f"transformers: {transformers.__version__}")
    logger.info(f"peft: {peft.__version__}")
    logger.info(f"faiss: {faiss.__version__}")
    logger.info(f"gradio: {gr.__version__}")
    logger.info(f"ir_datasets: {ir_datasets.__version__}")
    if torch.cuda.is_available():
        logger.info("\nCUDA Information:")
        logger.info("CUDA available: Yes")
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")
        logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        logger.info("\nCUDA Information:")
        logger.info("CUDA available: No")
log_system_info()
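
# Last-token pooling: use the hidden state of the final non-padding token
# (the appended EOS) as the sequence embedding, RepLLaMA-style.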
def pool(last_hidden_states, attention_mask, pool_type="last"):
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type == "last":
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = last_hidden[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb
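
# Tokenize to max_length - 1 to leave room for the EOS token appended to each sequence;
# tokenizer.pad then right-pads the batch (to a multiple of 8) and builds attention masks.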
def create_batch_dict(tokenizer, input_texts, always_add_eos="last", max_length=512):
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length - 1,
        return_token_type_ids=False,
        return_attention_mask=False,
        padding=False,
        truncation=True
    )
    if always_add_eos == "last":
        batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
    return tokenizer.pad(
        batch_dict,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )
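
# Wraps the Llama-2 base model with the Promptriever PEFT adapter (merged into the base
# weights at load time) and exposes encode(), which returns L2-normalized last-token embeddings.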
class RepLlamaModel:
    def __init__(self, model_name_or_path):
        self.base_model = "meta-llama/Llama-2-7b-hf"
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        self.tokenizer.model_max_length = 2048
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        self.model = self.get_model(model_name_or_path)
        self.model.config.max_length = 2048

    def get_model(self, peft_model_name):
        base_model = AutoModel.from_pretrained(self.base_model)
        model = PeftModel.from_pretrained(base_model, peft_model_name)
        model = model.merge_and_unload()
        model.eval()
        return model

    def encode(self, texts, batch_size=48, **kwargs):
        # if model is not on cuda, put it there
        if self.model.device.type != "cuda":
            self.model = self.model.cuda()
        all_embeddings = []
        for i in tqdm.tqdm(range(0, len(texts), batch_size)):
            batch_texts = texts[i:i+batch_size]
            batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
            batch_dict = {key: value.cuda() for key, value in batch_dict.items()}
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    outputs = self.model(**batch_dict)
                    embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
                    embeddings = F.normalize(embeddings, p=2, dim=-1)
                    logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
                    all_embeddings.append(embeddings.cpu().numpy())
        # self.model = self.model.cpu()
        return np.concatenate(all_embeddings, axis=0)
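
# Load pre-computed corpus embeddings from pickled shards (corpus_emb.<N>.pkl), each holding
# (embeddings, passage_id_lookup); shards are concatenated in numeric order.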
def load_corpus_embeddings(dataset_name):
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    index_files = glob.glob(corpus_path)
    index_files.sort(key=lambda x: int(x.split('.')[-2]))
    all_embeddings = []
    corpus_lookups = []
    for file in index_files:
        with open(file, 'rb') as f:
            embeddings, p_lookup = pickle.load(f)
            all_embeddings.append(embeddings)
            corpus_lookups.extend(p_lookup)
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    logger.info(f"Loaded corpus embeddings for {dataset_name}. Shape: {all_embeddings.shape}")
    return all_embeddings, corpus_lookups
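
# Build an exact inner-product index; since the embeddings are L2-normalized,
# inner product is equivalent to cosine similarity.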
def create_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    logger.info(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
    return index
def load_or_create_faiss_index(dataset_name):
    embeddings, corpus_lookups = load_corpus_embeddings(dataset_name)
    index = create_faiss_index(embeddings)
    return index, corpus_lookups
def initialize_faiss_and_corpus(dataset_name):
    global corpus_lookups
    index, corpus_lookups[dataset_name] = load_or_create_faiss_index(dataset_name)
    logger.info(f"Initialized FAISS index and corpus lookups for {dataset_name}")
    return index
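
# Retrieve the top-`depth` passages per query and map FAISS row indices back to corpus document ids.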
def search_queries(dataset_name, q_reps, depth=100):
    global faiss_index
    logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
    # Perform the search
    all_scores, all_indices = faiss_index.search(q_reps, depth)
    logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
    logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)
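
# Load queries and relevance judgments (qrels) for the dataset from ir_datasets (BEIR; test split for SciFact).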
def load_queries(dataset_name):
    global queries, q_lookups, qrels, query2qid
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
    queries[dataset_name] = []
    query2qid[dataset_name] = defaultdict(dict)
    q_lookups[dataset_name] = {}
    qrels[dataset_name] = {}
    for query in dataset.queries_iter():
        queries[dataset_name].append(query.text)
        q_lookups[dataset_name][query.query_id] = query.text
        query2qid[dataset_name][query.text] = query.query_id
    for qrel in dataset.qrels_iter():
        if qrel.query_id not in qrels[dataset_name]:
            qrels[dataset_name][qrel.query_id] = {}
        qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
    logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
    logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
def evaluate(qrels, results, k_values):
    qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
    results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
    )
    scores = evaluator.evaluate(results)
    metrics = {}
    for k in k_values:
        ndcg_scores = [query_scores[f"ndcg_cut_{k}"] for query_scores in scores.values()]
        recall_scores = [query_scores[f"recall_{k}"] for query_scores in scores.values()]
        metrics[f"NDCG@{k}"] = round(np.mean(ndcg_scores), 3)
        metrics[f"Recall@{k}"] = round(np.mean(recall_scores), 3)
        logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
        logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")
    # Drop the @100 metrics so only NDCG@10 and Recall@10 are reported
    del metrics["NDCG@100"]
    del metrics["Recall@100"]
    return metrics
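
# End-to-end evaluation: format each query as "query: <text> <prompt>", encode it with the model,
# retrieve from the FAISS index, and score the rankings against the qrels.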
@spaces.GPU
def run_evaluation(dataset, postfix):
    global current_dataset, queries, model, query2qid
    current_dataset = dataset

    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[current_dataset]]
    logger.info(f"Number of input texts: {len(input_texts)}")
    logger.info(f"Sample input text: {input_texts[0]}")

    q_reps = model.encode(input_texts)
    logger.info(f"Encoded query first five: {q_reps[0][:5]}")
    logger.info(f"Encoded query representations shape: {q_reps.shape}")

    all_scores, psg_indices = search_queries(dataset, q_reps)

    results = {}
    logger.info(f"Number of queries in q_lookups: {len(q_lookups[dataset])}")
    logger.info("Size of all_scores: " + str(len(all_scores)))
    logger.info("Size of psg_indices: " + str(len(psg_indices)))
    for query, scores, doc_ids in zip(queries[current_dataset], all_scores, psg_indices):
        qid = query2qid[dataset][query]
        qid_str = str(qid)
        results[qid_str] = {}
        for doc_id, score in zip(doc_ids, scores):
            doc_id_str = str(doc_id)
            results[qid_str][doc_id_str] = float(score)
        if not results[qid_str]:  # If no results for this query
            logger.warning(f"No results for query {qid_str}")
    logger.info(f"Number of queries in results: {len(results)}")
    logger.info(f"Sample result: {next(iter(results.items()))}")

    qrels[dataset] = {str(qid): {str(doc_id): rel for doc_id, rel in rels.items()}
                      for qid, rels in qrels[dataset].items()}
    logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
    logger.info(f"Sample qrel: {list(qrels[dataset].items())[0]}")

    # Check for mismatches
    qrels_keys = set(qrels[dataset].keys())
    results_keys = set(results.keys())
    logger.info(f"Queries in qrels but not in results: {qrels_keys - results_keys}")
    logger.info(f"Queries in results but not in qrels: {results_keys - qrels_keys}")

    metrics = evaluate(qrels[dataset], results, k_values=[10, 100])
    return metrics
@spaces.GPU
def gradio_interface(dataset, postfix):
    return run_evaluation(dataset, postfix)
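
# One-time startup: load the Promptriever model, the SciFact queries/qrels, and the FAISS index over the corpus embeddings.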
if model is None:
    model = RepLlamaModel(model_name_or_path=CUR_MODEL)
load_queries(current_dataset)
faiss_index = initialize_faiss_and_corpus(current_dataset)
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=datasets, label="Dataset", value="scifact"),
        gr.Textbox(label="Prompt")
    ],
    outputs=gr.JSON(label="Evaluation Results"),
    title="Promptriever Demo",
    description="Enter a prompt to evaluate the model's performance on SciFact. Note: it takes between **10-30 seconds** to evaluate.",
    examples=[
        ["scifact", ""],
        ["scifact", "Think carefully about these conditions when determining relevance"]
    ],
    cache_examples=False,
)
# Launch the interface
iface.launch(share=False)