import os

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

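# Load the chunked blog corpus, the retrieval embedding model, and the
# Gemma 2B IT model used to generate answers.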
dataset = load_dataset("ariG23498/pis-blogs-chunked")
embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2")

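# Gemma is a gated model, so pass an access token when loading it. This
# assumes the token is stored in an HF_TOKEN environment variable / Space secret.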
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", torch_dtype=torch.bfloat16, token=os.getenv("HF_TOKEN")
)


@spaces.GPU
def process_query(query):
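    # With ZeroGPU, CUDA is only available inside a @spaces.GPU-decorated
    # call, so models are moved to the GPU here rather than at import time.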
    # .to() moves the module in place; reassigning the global here would
    # raise UnboundLocalError without a `global` declaration.
    embedding_model.to("cuda")

    # Embed every chunk of the corpus (recomputed on each call, which is
    # acceptable for a small demo dataset).
    text_embeddings = embedding_model.encode(dataset["train"]["text"])

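    # Embed the query and score it against every chunk embedding.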
    query_embedding = embedding_model.encode(query)
    similarity_scores = embedding_model.similarity(query_embedding, text_embeddings)
    # argsort on the negated scores yields indices in descending order of
    # similarity; keep the five best matches.
    top_indices = (-similarity_scores).argsort()[0][:5]

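    # Keep only the single best-matching chunk as context.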
    best_index = top_indices[0].item()
    context = dataset["train"]["text"][best_index]
    url = dataset["train"]["url"][best_index]

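    # Build a prompt that grounds the answer in the retrieved chunk.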
    input_text = (
        f"Based on the context provided, '{context}', how would"
        f" you address the user's query regarding '{query}'? Please"
        " provide a detailed and contextually relevant response."
    )

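    # Tokenize the prompt, move the generator to the GPU, and decode greedily.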
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    len_text = len(input_text)
    model.to("cuda")
    with torch.inference_mode():
        generated_outputs = model.generate(**input_ids, max_new_tokens=1000, do_sample=False)
    generated_outputs = tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

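    # The decoded output echoes the prompt, so slice it off by character
    # length to return only the newly generated answer.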
    response = generated_outputs[0][len_text:]
    return url, response


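# Wire everything into a simple Gradio UI: a query box in, the source URL
# and the generated answer out.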
demo = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="User Query"),
    outputs=[gr.Textbox(label="URL"), gr.Textbox(label="Generated Response")],
)

demo.launch()