File size: 2,460 Bytes
33ffdb4
 
 
 
 
 
a3f8683
4d32fbd
33ffdb4
20309d7
4d32fbd
 
 
 
ee678e0
 
 
20309d7
5cdd539
33ffdb4
 
4d32fbd
ee678e0
4d32fbd
 
1ad4d26
 
33ffdb4
 
4d32fbd
 
 
33ffdb4
 
33a95fe
33ffdb4
 
 
 
 
 
 
 
71519e7
 
 
 
 
33ffdb4
71519e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import lancedb

# Copy the Space secret named "auth" into HF_TOKEN, presumably so the Hugging
# Face libraries can authenticate model downloads (google/gemma-2b-it below).
# os.environ only accepts str values: assigning os.getenv()'s None when the
# secret is missing raised TypeError at import time, so skip it in that case.
_auth_token = os.getenv("auth")
if _auth_token is not None:
    os.environ["HF_TOKEN"] = _auth_token

# LanceDB database in the local "embedding_dataset" directory; "my_table" holds
# the rows searched in process_query (each row has "text" and "url" fields).
db = lancedb.connect("embedding_dataset")
tbl = db.open_table("my_table")

# Query encoder. Presumably the same model that produced the table's vectors;
# query and document embeddings must come from one model — TODO confirm.
embedding_model = SentenceTransformer(
    model_name_or_path="all-mpnet-base-v2",
    device="cuda",
)

# Generator model id, shared by the tokenizer and the model so they stay in sync.
_GENERATOR_ID = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_ID)
model = AutoModelForCausalLM.from_pretrained(
    _GENERATOR_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

@spaces.GPU()
def process_query(query):
    """Answer *query* using the best-matching stored passage as context.

    Embeds the query, searches the LanceDB table by cosine metric, and prompts
    the Gemma model with the text of the top hit.

    Args:
        query: Free-text question from the "User Query" textbox.

    Returns:
        A ``(url, response)`` tuple: the ``url`` field of the top search hit
        and the model's generated answer. If the search returns no rows, the
        URL is empty and the response says that no context was found.
    """
    query_embedding = embedding_model.encode(query)
    # Up to 5 rows are fetched, but only the first (best) hit is used below.
    search_hits = tbl.search(query_embedding).metric("cosine").limit(5).to_list()
    if not search_hits:
        # Indexing [0] on an empty result used to raise IndexError.
        return "", "No relevant context was found for this query."

    top_hit = search_hits[0]
    context = top_hit["text"]
    url = top_hit["url"]

    print(url)

    # Adjacent string literals are joined with no whitespace, so the prompt
    # sections need explicit separators.
    input_text = (
        f"You are being provided a query: {query}\n"
        f"You are being provided context to the query: {context}\n"
        "Please provide a detailed and contextually relevant response."
    )

    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    prompt_token_count = input_ids["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_outputs = model.generate(**input_ids, max_new_tokens=1000, do_sample=False)

    # Decode only the tokens produced after the prompt. Slicing the decoded
    # text by len(input_text) assumed the prompt survives a tokenize/decode
    # round trip character-for-character, which is not guaranteed.
    response = tokenizer.decode(
        generated_outputs[0][prompt_token_count:], skip_special_tokens=True
    )
    return url, response

# demo = gr.Interface(
#     fn=process_query,
#     inputs=gr.Textbox(label="User Query"),
#     outputs=[gr.Textbox(label="URL"), gr.Textbox(label="Generated Response")]
# )

# demo.launch()


# Two-column page: the query box on the left; the retrieved URL and the
# generated answer on the right. The Submit button runs process_query.
with gr.Blocks() as demo:
    gr.Markdown("# RAG on PyImageSearch blog posts")
    gr.Markdown("This interface processes a user query by finding the most relevant context from PyImageSearch and generating a detailed response.")

    with gr.Row():
        with gr.Column():
            query_box = gr.Textbox(
                label="User Query",
                placeholder="Enter your query here...",
                lines=2,
            )
        with gr.Column():
            url_box = gr.Textbox(label="URL", interactive=False)
            answer_box = gr.Textbox(label="Generated Response", interactive=False)

    submit_btn = gr.Button("Submit")
    # process_query returns (url, response), matching this output order.
    submit_btn.click(fn=process_query, inputs=query_box, outputs=[url_box, answer_box])

demo.launch()