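"""Retrieval-augmented chat demo for the Design Research Collective.

This script loads a publications dataset with precomputed SPECTER embeddings,
builds a FAISS index over those embeddings, retrieves the abstracts most
similar to each user query, and passes them as context to a small
instruction-tuned LLM (Qwen2-0.5B-Instruct) behind a Gradio chat interface.
Approximate dependencies (exact package pins are an assumption): gradio,
transformers, torch, sentence-transformers, faiss-cpu or faiss-gpu, numpy,
pandas, datasets.
"""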
import gradio  # for the interface
import transformers  # to load an LLM
import sentence_transformers  # to load an embedding model
import faiss  # to create an index
import numpy  # to work with vectors
import pandas  # to work with dataframes
import json  # to work with JSON
import datasets  # to load the dataset

# Load the dataset and convert to pandas
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Define the base URL for Google Scholar
SCHOLAR_URL = "https://scholar.google.com"

# Filter out any publications without an abstract
missing_abstract = [
    '"abstract": null' in json.dumps(bibdict)
    for bibdict in full_data["bib_dict"].values
]
data = full_data[~pandas.Series(missing_abstract)].reset_index(drop=True)

# Create a FAISS index for fast similarity search
# (inner product over L2-normalized vectors is equivalent to cosine similarity)
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype(numpy.float32)
gpu_index = faiss.IndexFlatIP(vectors.shape[1])
# res = faiss.StandardGpuResources()  # uncomment to run the index on a single GPU
# gpu_index = faiss.index_cpu_to_gpu(res, 0, gpu_index)
faiss.normalize_L2(vectors)
gpu_index.add(vectors)

# Load the embedding model for queries (assumed to be the same SPECTER model that
# produced the precomputed "embedding" column, so queries share its vector space)
model = sentence_transformers.SentenceTransformer("allenai-specter")

# Define the search function
def search(query: str, k: int) -> tuple[str, str]:
    query = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(query)
    D, I = gpu_index.search(query, k)
    top_hits = data.loc[I[0]]

    search_results = "You are an AI assistant who delights in helping people " \
        + "learn about research from the Design Research Collective. Here are " \
        + "several really cool abstracts:\n\n"

    references = "\n\n## References\n\n"

    for i in range(k):
        search_results += top_hits["bib_dict"].values[i]["abstract"] + "\n"
        references += str(i + 1) + ". [" + top_hits["bib_dict"].values[i]["title"] + "]" \
            + "(" + SCHOLAR_URL + "/citations?view_op=view_citation&citation_for_view=" \
            + top_hits["author_pub_id"].values[i] + ")\n"

    search_results += "\nSummarize the above abstracts as you respond to the following query:"

    print(search_results)

    return search_results, references
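# Example call (hypothetical query): prompt_block, refs = search("human-AI collaboration in design", 5)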


# Create an LLM pipeline that we can send queries to
pipe = transformers.pipeline(
    "text-generation",
    model="Qwen/Qwen2-0.5B-Instruct",
    trust_remote_code=True,
    max_new_tokens=512,
    device="cuda:0",  # assumes a GPU is present; use device="cpu" to run on CPU
)

def preprocess(message: str) -> tuple[str, str]:
    """Applies a preprocessing step to the user's message before the LLM receives it"""
    block_search_results, formatted_search_results = search(message, 5)
    return block_search_results + message, formatted_search_results

def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """Applies a postprocessing step to the LLM's response before the user receives it"""
    return response + bypass_from_preprocessing
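# The "bypass" pattern: the reference list built during preprocessing skips the LLM
# entirely and is appended verbatim to the model's answer during postprocessing.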

def predict(message: str, history: list[list[str]]) -> str:
    """Craft a response: retrieve relevant abstracts, query the LLM, and append references"""

    # Apply preprocessing
    message, bypass = preprocess(message)

    # Keep only the most recent [user, assistant] exchange from Gradio's history and
    # convert it to the chat-message format that the transformers pipeline expects
    if isinstance(history, list):
        if len(history) > 0:
            history = history[-1]
    history_transformer_format = [
        {"role": "assistant" if idx % 2 else "user", "content": msg}
        for idx, msg in enumerate(history)
    ] + [{"role": "user", "content": message}]

    # Create a response
    response = pipe(history_transformer_format)
    response_message = response[0]["generated_text"][-1]["content"]

    # Apply postprocessing
    response_message = postprocess(response_message, bypass)

    return response_message


# Create and run the gradio interface
gradio.ChatInterface(predict).launch(debug=True)