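"""Semantic documentation search demo.

Embeds a user question with a sentence-transformers model, optionally filters a
FAISS-indexed Hugging Face dataset of documentation chunks by product keywords,
and renders the best matches as HTML in a small Gradio app.
"""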
import gradio as gr
from datasets import load_dataset, load_from_disk, Dataset
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import base64
import html
import re
from functools import partial  # only partial is used below
model_ckpt = "nomic-ai/nomic-embed-text-v1.5"
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
model_ckpt = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt, trust_remote_code=True)
device = torch.device("cpu")
model.to(device)
def cls_pooling(model_output):
    # Use the hidden state of the [CLS] token as the sentence embedding.
    return model_output.last_hidden_state[:, 0]
def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    with torch.no_grad():  # inference only, no gradients needed
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
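# Pre-embedded documentation chunks (the "embeddings" column holds the vectors);
# index that column with FAISS for nearest-neighbour search.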
embeddings_dataset = load_dataset("Vadim212/doctest2")["train"]
embeddings_dataset.add_faiss_index(column="embeddings")
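# Keep a row when the given field matches the requested value, or when the field is unset.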
def filter_f(x, field, value):
    return x[field] == value or x[field] is None
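# Narrow the dataset by the product line (reports/dashboards/forms) and product
# (js, wpf, blazor, ...) mentioned in the question, before the vector search.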
def filter_product(question, dataset):
    product_line = ""
    for pl in ["report", "dashboard", "form"]:
        # Match singular, plural and possessive forms, e.g. "report", "reports", "report's".
        if re.search(rf"\b({pl}|{pl}s|{pl}'s)\b", question, re.IGNORECASE):
            product_line = f"{pl}s"
    product = ""
    if re.search(r"\b(javascript|java script)\b", question, re.IGNORECASE):
        product = "js"
    for pr in ["web", "net", "js", "wpf", "php", "blazor", "angular", "java"]:
        if re.search(rf"{pr}\b", question, re.IGNORECASE):
            product = pr
    if product_line != "" or product != "":
        product_line_filter = partial(filter_f, field="product_line", value=product_line) if product_line != "" else lambda x: True
        product_filter = partial(filter_f, field="product", value=product) if product != "" else lambda x: True
        # filter() does not work on a dataset with an attached index, so drop it first.
        if len(dataset.list_indexes()) > 0:
            dataset.drop_index("embeddings")
        dataset = dataset.filter(lambda x: product_line_filter(x) and product_filter(x))
    return dataset
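# Render one result row (title, origin/product tags, score, path, content and link) as an HTML snippet.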
def get_html(row):
    pr = row["product"]
    product_ = f" [{pr}]" if pr is not None else ""
    product_line = f" [{row.product_line}]" if row.product_line is not None else ""
    path = f"[{row.path}]" if row.path is not None and len(row.path) > 0 else ""
    result = f"""<div style='font-size:16pt'>{html.escape(row.title)} <font style='font-size:10pt'>[{row.origin}]{product_line}{product_}({row.scores})</font> </div>
<div style='font-size:8pt'>{html.escape(path)}</div>
{html.escape(row.content)}<br><a href='{row.url}' target='_blank'>Link</a>"""
    return result
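# Embed the question, filter the corpus, run a FAISS nearest-neighbour search,
# and return up to 10 de-duplicated results as HTML strings.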
def find(question):
    question_embedding = get_embeddings([question]).cpu().detach().numpy()
    dataset = filter_product(question, embeddings_dataset)
    # filter_product drops the FAISS index when it filters, so rebuild it if needed.
    if len(dataset.list_indexes()) == 0:
        dataset.add_faiss_index(column="embeddings")
    scores, samples = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=20
    )
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    samples_df.sort_values("scores", ascending=True, inplace=True)
    samples_df.drop_duplicates(subset=["url"], inplace=True)
    samples_df = samples_df[:10]
    result = [get_html(row) for _, row in samples_df.iterrows()]
    return result
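# Minimal Gradio UI: a prompt box, a "Find" button and ten HTML slots for the results.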
demo = gr.Blocks()
with demo:
    inp = gr.Textbox(
        placeholder="Enter prompt",
        label="Prompt like: 'how to export to PDF?', 'What is report?', 'blazor designer', 'how to render report?', 'List of export formats', 'Supported databases'",
    )
    find_btn = gr.Button("Find")
    big_block = [gr.HTML("") for _ in range(10)]
    find_btn.click(find, inputs=inp, outputs=big_block)
demo.launch(inline=True, width=400)