test2 / app.py
Vadim212's picture
Update app.py
3cd3698 verified
import gradio as gr
from datasets import load_dataset, load_from_disk, Dataset
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import base64
import html
import re
from functools import *
# Sentence-embedding checkpoint used to embed user queries.
# Previously tried alternatives (kept for reference):
#   "nomic-ai/nomic-embed-text-v1.5"  -- ships custom modelling code and needs
#                                        AutoModel.from_pretrained(..., trust_remote_code=True)
#   "sentence-transformers/multi-qa-mpnet-base-dot-v1"
# NOTE(review): query embeddings must come from the same model (and pooling)
# that produced the "embeddings" column of the Vadim212/doctest2 dataset --
# confirm before switching checkpoints.
model_ckpt = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# The current checkpoint needs no remote code, so don't allow it to execute any.
model = AutoModel.from_pretrained(model_ckpt)
device = torch.device("cpu")
model.to(device)
# Inference only: keep dropout etc. disabled (from_pretrained already does
# this; stated explicitly for clarity).
model.eval()
def cls_pooling(model_output):
    """Pool a transformer output by keeping the first token of every sequence.

    Args:
        model_output: object exposing ``last_hidden_state``; indexed as
            ``[batch, position, ...]`` here, so position 0 (the [CLS] slot for
            BERT-style tokenizers) is selected for each batch item.

    Returns:
        ``last_hidden_state[:, 0]`` -- one vector per sequence.

    NOTE(review): the multi-qa-MiniLM-L6-cos-v1 model card recommends mean
    pooling; CLS pooling is presumably kept because the stored dataset
    embeddings were produced the same way -- confirm before changing.
    """
    hidden_states = model_output.last_hidden_state
    first_token_states = hidden_states[:, 0]
    return first_token_states
def get_embeddings(text_list):
    """Embed a batch of strings with the module-level ``tokenizer``/``model``.

    Args:
        text_list: list of strings to encode. They are padded to a common
            length and truncated (``truncation=True``).

    Returns:
        A tensor on ``device`` holding one CLS-pooled vector per input string
        (presumably shaped ``(len(text_list), hidden_size)``). No autograd
        graph is attached.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
# Documents with precomputed vectors in the "embeddings" column; that column
# is searched through a FAISS index.
embeddings_dataset = load_dataset("Vadim212/doctest2", split="train")
# NOTE: filter_product() may drop this index from this very object when a
# product filter applies; find() re-adds an index whenever it is missing.
embeddings_dataset.add_faiss_index(column="embeddings")
def filter_f(x, field, value):
return x[field] == value or x[field] is None
def filter_product(question, dataset):
product_line = ""
for pl in ["report", "dashboard", "form"]:
if re.search(f"\\b({pl}|{pl}s|{pl}`s)\\b", question, re.IGNORECASE) is not None: product_line = f"{pl}s"
product = ""
if re.search("\\b(javascript|java script)\\b", question, re.IGNORECASE) is not None: product = "js"
for pr in ["web", "net", "js", "wpf", "php", "blazor", "angular", "java"]:
if re.search(f"{pr}\\b", question, re.IGNORECASE) is not None: product = pr
if (product_line != "") or (product != ""):
product_line_filter = partial(filter_f, field = "product_line", value = product_line) if product_line != "" else lambda x: True
product_filter = partial(filter_f, field = "product", value = product) if product != "" else lambda x: True
if len(dataset.list_indexes()) > 0 : dataset.drop_index("embeddings")
dataset = dataset.filter(lambda x: product_line_filter(x) and product_filter(x))
return dataset
def get_html(row):
    """Render one search hit as an HTML fragment.

    Args:
        row: a pandas Series (one row of the results frame built in ``find``)
            with ``title``, ``origin``, ``product``, ``product_line``, ``path``,
            ``content``, ``url`` and ``scores`` entries. Fields are read with
            ``row[...]`` because some names collide with Series attributes
            (``Series.product`` is a method).

    Returns:
        An HTML string: a header with title, origin/product tags and score; the
        path; the content; and a link to ``url``. Every text value, including
        the URL placed inside the ``href`` attribute, is HTML-escaped.
    """
    title = row["title"]
    origin = row["origin"]
    product = row["product"]
    product_line = row["product_line"]
    path = row["path"]
    content = row["content"]
    url = row["url"]
    scores = row["scores"]

    product_tag = f" [{html.escape(str(product))}]" if product is not None else ""
    product_line_tag = (
        f" [{html.escape(str(product_line))}]" if product_line is not None else ""
    )
    path_tag = f"[{path}]" if path is not None and len(path) > 0 else ""
    # quote=True also escapes ' so the URL cannot break out of href='...'.
    safe_url = html.escape(str(url), quote=True)
    safe_origin = html.escape(str(origin))
    result = f"""<div style='font-size:16pt'>{html.escape(title)} <font style='font-size:10pt'>[{safe_origin}]{product_line_tag}{product_tag}({scores})</font> </div>
<div style='font-size:8pt'>{html.escape(path_tag)}</div>
{html.escape(content)}<br><a href='{safe_url}' target='_blank'>Link</a>"""
    return result
def find(question):
    """Return HTML snippets for the documents nearest to ``question``.

    The question is embedded, the dataset is optionally narrowed by
    ``filter_product``, and the 20 nearest rows are fetched from the FAISS
    index. Only the best-scoring row per URL is kept, and at most 10 hits are
    rendered with ``get_html``.

    Returns:
        A list of exactly 10 HTML strings, one per ``gr.HTML`` output slot in
        the UI. When fewer than 10 distinct URLs are found, the remaining slots
        are empty strings. This clears stale results and returns as many values
        as there are outputs.
    """
    max_results = 10  # must match the number of gr.HTML output components
    question_embedding = get_embeddings([question]).cpu().detach().numpy()
    dataset = filter_product(question, embeddings_dataset)
    # filter_product may return a freshly filtered (unindexed) dataset, or the
    # module-level one after it dropped its index.
    if not dataset.list_indexes():
        dataset.add_faiss_index(column="embeddings")
    scores, samples = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=20
    )
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    # Scores are treated as distances (smaller = closer), hence ascending;
    # drop_duplicates keeps the first, i.e. closest, row for each URL.
    samples_df = (
        samples_df.sort_values("scores", ascending=True)
        .drop_duplicates(subset=["url"])
        .head(max_results)
    )
    result = [get_html(row) for _, row in samples_df.iterrows()]
    result += [""] * (max_results - len(result))
    return result
# Minimal search UI: one prompt box, a Find button, and ten HTML result slots.
with gr.Blocks() as demo:
    inp = gr.Textbox(
        placeholder="Enter prompt",
        label="Prompt like: 'how to export to PDF?', 'What is report?', 'blazor designer', 'how to render report?', 'List of export formats', 'Supported databases'",
    )
    find_btn = gr.Button("Find")
    # find() must return one value per slot in this list.
    big_block = [gr.HTML("") for _ in range(10)]
    find_btn.click(fn=find, inputs=inp, outputs=big_block)
demo.launch(inline=True, width=400)