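"""Gradio Space for evaluating RepLLaMA-style retrieval with custom query prompts.

Loads a merged Llama-2-7b + LoRA retriever
(orionweller/repllama-instruct-hard-positives-v2-joint), encodes BEIR test
queries with a user-supplied prefix/postfix prompt, searches pre-computed
corpus embeddings with a Faiss flat index, and reports nDCG@10 and
Recall@100 via pyserini's trec_eval wrapper.
"""
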
import gradio as gr
import pickle
import numpy as np
import glob
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from tevatron.retriever.searcher import FaissFlatSearcher
import logging
import os
import json
import spaces
import ir_datasets
import subprocess

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
base_model = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
retriever = None
corpus_lookup = None
queries = None
q_lookup = None
loaded_dataset = None  # tracks which dataset's corpus/queries are currently loaded

def load_model():
    global tokenizer, model
    # Llama-2 has no pad token, so reuse the EOS token and pad on the right
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Load the base model, attach the LoRA adapter, and merge it for faster inference
    base_model_instance = AutoModel.from_pretrained(base_model)
    model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
    model = model.merge_and_unload()
    model.eval()
    model.cuda()

def load_corpus_embeddings(dataset_name):
    global retriever, corpus_lookup
    # Pre-computed corpus embeddings are stored as pickled (reps, lookup) shards per dataset
    corpus_path = f"{dataset_name}/corpus_emb*"
    index_files = glob.glob(corpus_path)
    logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')

    # Initialize the Faiss flat index from the first shard, then add the rest
    p_reps_0, p_lookup_0 = pickle_load(index_files[0])
    retriever = FaissFlatSearcher(p_reps_0)

    shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
    corpus_lookup = []
    for p_reps, p_lookup in tqdm(shards, desc='Loading shards into index', total=len(index_files)):
        retriever.add(p_reps)
        corpus_lookup += p_lookup

def pickle_load(path):
    with open(path, 'rb') as f:
        reps, lookup = pickle.load(f)
    return np.array(reps), lookup

def load_queries(dataset_name):
    global queries, q_lookup
    # BEIR test queries via ir_datasets: keep a list (for encoding) and an id -> text lookup
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}/test")
    queries = []
    q_lookup = {}
    for query in dataset.queries_iter():
        queries.append(query.text)
        q_lookup[query.query_id] = query.text

def encode_queries(prefix, postfix):
    global queries
    input_texts = [f"{prefix}Query: {query} {postfix}".strip() for query in queries]

    encoded_embeds = []
    batch_size = 32  # Adjust as needed
    for start_idx in range(0, len(input_texts), batch_size):
        batch_input_texts = input_texts[start_idx: start_idx + batch_size]
        inputs = tokenizer(batch_input_texts, padding=True, truncation=True, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
            # RepLLaMA-style pooling: take the hidden state of the last non-padding token
            # (Llama is decoder-only, so the first-token embedding carries little signal)
            last_hidden = outputs.last_hidden_state
            seq_lengths = inputs["attention_mask"].sum(dim=1) - 1
            embeds = last_hidden[torch.arange(last_hidden.size(0), device=last_hidden.device), seq_lengths]
            embeds = F.normalize(embeds, p=2, dim=-1)
            encoded_embeds.append(embeds.cpu().numpy())

    return np.concatenate(encoded_embeds, axis=0)

def search_queries(q_reps, depth=1000):
    all_scores, all_indices = retriever.search(q_reps, depth)
    # Map Faiss row indices back to corpus document ids
    psg_indices = [[str(corpus_lookup[x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)

def write_ranking(corpus_indices, corpus_scores, ranking_save_file):
    # Write tab-separated "qid <tab> docid <tab> score" lines, the input format
    # expected by tevatron.utils.format.convert_result_to_trec below.
    # q_lookup preserves insertion order, so its keys line up with the encoded query order.
    with open(ranking_save_file, 'w') as f:
        for qid, q_doc_scores, q_doc_indices in zip(q_lookup.keys(), corpus_scores, corpus_indices):
            score_list = sorted(zip(q_doc_scores, q_doc_indices), key=lambda x: x[0], reverse=True)
            for s, idx in score_list:
                f.write(f'{qid}\t{idx}\t{s}\n')

def evaluate_with_subprocess(dataset, ranking_file):
    # Convert the tab-separated ranking to TREC run format
    trec_file = f"rank.{dataset}.trec"
    convert_cmd = [
        "python", "-m", "tevatron.utils.format.convert_result_to_trec",
        "--input", ranking_file,
        "--output", trec_file,
        "--remove_query"
    ]
    subprocess.run(convert_cmd, check=True)

    # Evaluate with pyserini's trec_eval wrapper against the prebuilt BEIR qrels
    eval_cmd = [
        "python", "-m", "pyserini.eval.trec_eval",
        "-c", "-mrecall.100", "-mndcg_cut.10",
        f"beir-v1.0.0-{dataset}-test", trec_file
    ]
    result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)

    # Parse the trec_eval output by metric name rather than relying on line order
    metrics = {}
    for line in result.stdout.strip().split('\n'):
        parts = line.split()
        if len(parts) == 3:
            metrics[parts[0]] = float(parts[2])
    ndcg_10 = metrics.get("ndcg_cut_10", float("nan"))
    recall_100 = metrics.get("recall_100", float("nan"))

    # Clean up temporary files
    os.remove(ranking_file)
    os.remove(trec_file)

    return f"nDCG@10: {ndcg_10:.4f}, Recall@100: {recall_100:.4f}"

@spaces.GPU
def run_evaluation(dataset, prefix, postfix):
    global queries, q_lookup, loaded_dataset

    # Load corpus embeddings and queries if this dataset is not already loaded
    if loaded_dataset != dataset:
        load_corpus_embeddings(dataset)
        load_queries(dataset)
        loaded_dataset = dataset

    # Encode queries
    q_reps = encode_queries(prefix, postfix)

    # Search
    all_scores, psg_indices = search_queries(q_reps)

    # Write ranking
    ranking_file = f"temp_ranking_{dataset}.txt"
    write_ranking(psg_indices, all_scores, ranking_file)

    # Evaluate
    results = evaluate_with_subprocess(dataset, ranking_file)

    return results

def gradio_interface(dataset, prefix, postfix):
    return run_evaluation(dataset, prefix, postfix)

# Load model
load_model()

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=["scifact", "arguana"], label="Dataset"),
        gr.Textbox(label="Prefix prompt"),
        gr.Textbox(label="Postfix prompt")
    ],
    outputs=gr.Textbox(label="Evaluation Results"),
    title="Query Evaluation with Custom Prompts",
    description="Select a dataset and enter prefix and postfix prompts to evaluate queries using Pyserini."
)

# Launch the interface
iface.launch()