File size: 3,882 Bytes
a485a6e
 
 
 
 
 
e9ff287
3a28b7b
9dfdb8b
a485a6e
 
fe7f2e1
ffc9504
a485a6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782bcf3
a485a6e
 
9dfdb8b
 
 
 
 
57aa44a
 
9dfdb8b
 
 
 
 
 
e2ec042
 
 
9dfdb8b
0fa8b59
e2c8568
9dfdb8b
 
 
04a10a3
f9a30d0
 
5a541ea
9423e5e
01187bb
ac84a99
40fe59a
04a10a3
 
a485a6e
 
 
 
9dfdb8b
e2c8568
9dfdb8b
 
3cd3698
a485a6e
 
 
 
 
4b4d8aa
 
a485a6e
04a10a3
a485a6e
 
 
 
 
 
37482f6
a485a6e
7e11f55
a485a6e
 
 
 
 
9ac0075
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from datasets import load_dataset, load_from_disk, Dataset
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import base64
import html
import re
from functools import *

# Embedding checkpoint. Earlier candidates are kept as comments for reference
# (the original code assigned all three in sequence; only the last took effect).
# model_ckpt = "nomic-ai/nomic-embed-text-v1.5"
# model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
model_ckpt = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# trust_remote_code is required by the nomic checkpoint; harmless for the others.
model = AutoModel.from_pretrained(model_ckpt, trust_remote_code=True)

device = torch.device("cpu")
model.to(device)


def cls_pooling(model_output):
    """Return the [CLS] (first-position) embedding for each sequence in the batch."""
    hidden = model_output.last_hidden_state
    return hidden[:, 0]

def get_embeddings(text_list):
    """Embed a batch of strings with the module-level tokenizer/model.

    Tokenizes with padding/truncation, moves tensors to the module-level
    ``device``, and returns a (batch, hidden) tensor of CLS-token embeddings.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: disable autograd so no graph is built (saves memory/time;
    # the caller detaches anyway).
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)


# Pre-embedded documentation corpus; the "embeddings" column holds vectors
# presumably produced with the model configured above — confirm they match.
# NOTE(review): loaded from the HF Hub, so startup requires network access.
embeddings_dataset = load_dataset("Vadim212/doctest2")["train"]
embeddings_dataset.add_faiss_index(column="embeddings")

def filter_f(x, field, value):
    """True when row ``x`` has ``value`` in ``field`` or the field is unset (None)."""
    current = x[field]
    return current is None or current == value

def filter_product(question, dataset):
    """Narrow ``dataset`` to rows matching a product line / product named in ``question``.

    Scans the question for product-line words (report/dashboard/form — singular,
    plural, or with a trailing `s) and for product tokens (web, net, js, ...).
    When either is found, the FAISS index is dropped (``Dataset.filter`` cannot
    run with an index attached) and a filtered copy is returned; otherwise
    ``dataset`` is returned unchanged.

    NOTE(review): ``drop_index`` mutates the dataset object passed in (the
    shared module-level one); ``find`` re-adds the index when missing.
    """
    product_line = ""
    for pl in ["report", "dashboard", "form"]:
        # Whole-word match on "report", "reports" or "report`s"; value stored
        # in the dataset is the plural form.
        if re.search(rf"\b({pl}|{pl}s|{pl}`s)\b", question, re.IGNORECASE) is not None: product_line  = f"{pl}s"

    product = ""
    # "javascript" maps onto the "js" product code.
    if re.search(r"\b(javascript|java script)\b", question, re.IGNORECASE) is not None: product = "js"
    for pr in ["web", "net", "js", "wpf", "php", "blazor", "angular", "java"]:
        # NOTE(review): no leading \b, so e.g. "internet" matches "net" — looks
        # intentional for tokens like ".net", but confirm. Last match wins.
        if re.search(rf"{pr}\b", question, re.IGNORECASE) is not None: product = pr

    if (product_line != "") or (product != ""):
        product_line_filter = partial(filter_f, field = "product_line", value = product_line) if product_line != "" else lambda x: True
        product_filter = partial(filter_f, field = "product", value = product) if product != "" else lambda x: True

        if len(dataset.list_indexes()) > 0 : dataset.drop_index("embeddings")
        dataset = dataset.filter(lambda x: product_line_filter(x) and product_filter(x))

    return dataset

def get_html(row):
    """Render one result row (a pandas Series) as an HTML snippet string."""
    pr = row["product"]
    if pr is not None:
        product_ = f" [{pr}]"
    else:
        product_ = ""
    product_line = "" if row.product_line is None else f" [{row.product_line}]"
    has_path = row.path is not None and len(row.path) > 0
    path = f"[{row.path}]" if has_path else ""
    # Title and path are escaped; content/url are trusted HTML from the corpus.
    result = f"""<div style='font-size:16pt'>{html.escape(row.title)} <font style='font-size:10pt'>[{row.origin}]{product_line}{product_}({row.scores})</font> </div>
             <div style='font-size:8pt'>{html.escape(path)}</div>
             {html.escape(row.content)}<br><a href='{row.url}' target='_blank'>Link</a>"""
    return result

def find(question):
    """Embed *question*, retrieve the 20 nearest corpus rows, and return up to
    10 result snippets as HTML strings (one per gr.HTML output slot)."""

    question_embedding = get_embeddings([question]).cpu().detach().numpy()

    # filter_product may return a filtered copy without a FAISS index, and may
    # have dropped the index on the shared dataset as a side effect.
    dataset = filter_product(question, embeddings_dataset)
    if len(dataset.list_indexes()) == 0: dataset.add_faiss_index(column="embeddings")

    scores, samples = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=20
    )

    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    # NOTE(review): ascending sort treats lower scores as better — correct for
    # an L2-distance FAISS index, wrong if the index returns similarities;
    # confirm against how the embeddings index was built.
    samples_df.sort_values("scores", ascending=True, inplace=True)
    # Dedup by URL after sorting so the best-scoring duplicate survives.
    samples_df.drop_duplicates(subset=["url"], inplace=True)
    samples_df = samples_df[:10]
    
    result = [get_html(row) for i, row in samples_df.iterrows()]
    return result


# --- Gradio UI: one prompt textbox, a Find button, ten HTML result slots. ---
demo = gr.Blocks()

with demo:
    inp = gr.Textbox(placeholder="Enter prompt",label= "Prompt like: 'how to export to PDF?', 'What is report?', 'blazor designer', 'how to render report?', 'List of export formats', 'Supported databases'" )
    find_btn = gr.Button("Find")
    # One gr.HTML component per result; find() returns a matching list of 10.
    big_block = [gr.HTML("") for i in range(10)]
    
    find_btn.click(find, 
               inputs=inp, 
               outputs=big_block)

demo.launch(inline=True, width=400)