ccm committed
Commit d7a54c3 · verified · 1 Parent(s): 50d362b

Update main.py

Files changed (1):
  1. main.py +82 -73
main.py CHANGED
@@ -1,13 +1,11 @@
-import json  # For stringifying a dict
-import random  # For selecting a search hint
-
-import gradio  # GUI framework
-import datasets  # Used to load publication dataset
-
-import numpy  # For a few simple matrix operations
-import pandas  # Needed for operating on dataset
-import sentence_transformers  # Needed for query embedding
-import faiss  # Needed for fast similarity search
+import gradio  # for the interface
+import transformers  # to load an LLM
+import sentence_transformers  # to load an embedding model
+import faiss  # to create an index
+import numpy  # to work with vectors
+import pandas  # to work with dataframes
+import json  # to work with JSON
+import datasets  # to load the dataset
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
@@ -24,77 +22,88 @@ data = full_data[~pandas.Series(filter)]
 data.reset_index(inplace=True)
 
 # Create a FAISS index for fast similarity search
-indices = []
-metrics = [faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2]
-normalization = [True, False]
+metric = faiss.METRIC_INNER_PRODUCT
 vectors = numpy.stack(data["embedding"].tolist(), axis=0)
-for metric in metrics:
-    for normal in normalization:
-        index = faiss.IndexFlatL2(len(data["embedding"][0]))
-        index.metric_type = metric
-        if normal:
-            faiss.normalize_L2(vectors)
-        index.train(vectors)
-        index.add(vectors)
-        indices.append(index)
+gpu_index = faiss.IndexFlatL2(len(data["embedding"][0]))
+# res = faiss.StandardGpuResources()  # use a single GPU
+# gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
+gpu_index.metric_type = metric
+faiss.normalize_L2(vectors)
+gpu_index.train(vectors)
+gpu_index.add(vectors)
 
 # Load the model for later use in embeddings
 model = sentence_transformers.SentenceTransformer("allenai-specter")
 
 # Define the search function
-def search(query: str, k: int, n: int):
+def search(query: str, k: int) -> tuple[str, str]:
     query = numpy.expand_dims(model.encode(query), axis=0)
     faiss.normalize_L2(query)
-    D, I = indices[n].search(query, k)
+    D, I = gpu_index.search(query, k)
     top_five = data.loc[I[0]]
-    search_results = ""
+
+    search_results = "You are an AI assistant who delights in helping people " \
+        + "learn about research from the Design Research Collective. Here are " \
+        + "several really cool abstracts:\n\n"
+
+    references = "\n\n## References\n\n"
 
     for i in range(k):
-        search_results += "### " + top_five["bib_dict"].values[i]["title"] + "\n\n"
-        search_results += str(int(100*D[0][i])) + "% relevant "
-        if top_five["author_pub_id"].values[i] is not None:
-            search_results += "/ [Full Text](https://scholar.google.com/citations?view_op=view_citation&citation_for_view=" + top_five["author_pub_id"].values[i] + ") "
-        if top_five["citedby_url"].values[i] is not None:
-            search_results += (
-                "/ [Cited By](" + SCHOLAR_URL + top_five["citedby_url"].values[i] + ") "
-            )
-        if top_five["url_related_articles"].values[i] is not None:
-            search_results += (
-                "/ [Related Articles]("
-                + SCHOLAR_URL
-                + top_five["url_related_articles"].values[i]
-                + ") "
-            )
-        search_results += "\n\n```bibtex\n"
-        search_results += (
-            json.dumps(top_five["bibtex"].values[i], indent=4)
-            .replace("\\n", "\n")
-            .replace("\\t", "\t")
-            .strip('"')
-        )
-        search_results += "```\n"
-    return search_results
-
-
-with gradio.Blocks() as demo:
-    with gradio.Group():
-        query = gradio.Textbox(
-            placeholder=random.choice([
-                "design for additive manufacturing",
-                "best practices for agent based modeling",
-                "arctic environmental science",
-                "analysis of student teamwork"
-            ]),
-            show_label=False,
-            lines=1,
-            max_lines=1
-        )
-        with gradio.Accordion("Settings", open=False):
-            k = gradio.Number(10.0, label="Number of results", precision=0)
-            n = gradio.Radio([True, False], value=True, label="Normalized")
-    results = gradio.Markdown()
-    query.change(fn=search, inputs=[query, k, n], outputs=results)
-    k.change(fn=search, inputs=[query, k, n], outputs=results)
-    n.change(fn=search, inputs=[query, k, n], outputs=results)
-
-demo.launch(debug=True)
+        search_results += top_five["bib_dict"].values[i]["abstract"] + "\n"
+        references += str(i + 1) + ". [" + top_five["bib_dict"].values[i]["title"] + "]" \
+            + "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view=" + top_five["author_pub_id"].values[i] + ")\n"
+
+    search_results += "\nSummarize the above abstracts as you respond to the following query:"
+
+    print(search_results)
+
+    return search_results, references
+
+
+# Create an LLM pipeline that we can send queries to
+pipe = transformers.pipeline(
+    "text-generation",
+    model="Qwen/Qwen2-0.5B-Instruct",
+    # model="microsoft/Phi-3-medium-128k-instruct-onnx-cuda",
+    # model="microsoft/Phi-3-medium-128k-instruct",
+    trust_remote_code=True,
+    max_new_tokens=512,
+    device="cuda:0",
+)
+
+def preprocess(message: str) -> tuple[str, str]:
+    """Applies a preprocessing step to the user's message before the LLM receives it"""
+    block_search_results, formatted_search_results = search(message, 5)
+    return block_search_results + message, formatted_search_results
+
+def postprocess(response: str, bypass_from_preprocessing: str) -> str:
+    """Applies a postprocessing step to the LLM's response before the user receives it"""
+    return response + bypass_from_preprocessing
+
+def predict(message: str, history: list[str]) -> str:
+    """This function is responsible for crafting a response"""
+
+    # Apply preprocessing
+    message, bypass = preprocess(message)
+
+    # This is some handling that is applied to the history variable to put it in a good format
+    if isinstance(history, list):
+        if len(history) > 0:
+            history = history[-1]
+    history_transformer_format = [
+        {"role": "assistant" if idx & 1 else "user", "content": msg}
+        for idx, msg in enumerate(history)
+    ] + [{"role": "user", "content": message}]
+
+    # Create a response
+    response = pipe(history_transformer_format)
+    response_message = response[0]["generated_text"][-1]["content"]
+
+    # Apply postprocessing
+    response_message = postprocess(response_message, bypass)
+
+    return response_message
+
+
+# Create and run the gradio interface
+gradio.ChatInterface(predict).launch(debug=True)
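
A note on the index construction: the new code instantiates faiss.IndexFlatL2 and then overrides metric_type to METRIC_INNER_PRODUCT. Because the vectors are L2-normalized first, an equivalent and arguably clearer construction is faiss.IndexFlatIP, where the inner product of normalized vectors is exactly cosine similarity. A minimal self-contained sketch, with random vectors standing in for data["embedding"] and the sizes chosen arbitrarily:

import faiss
import numpy

# Hypothetical stand-in for data["embedding"]: 100 vectors of dimension 768
vectors = numpy.random.rand(100, 768).astype("float32")
faiss.normalize_L2(vectors)  # in place; afterwards inner product == cosine similarity

index = faiss.IndexFlatIP(vectors.shape[1])  # flat inner-product index
index.add(vectors)  # flat indexes need no training step

query = numpy.random.rand(1, 768).astype("float32")
faiss.normalize_L2(query)
D, I = index.search(query, 5)  # D: similarity scores, I: row indices into the dataset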
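On the history handling: predict appears to assume the pairs-style history that gradio.ChatInterface used by default at the time, a list of [user, assistant] pairs, of which history[-1] keeps only the most recent exchange before converting it to the role/content messages the transformers pipeline expects. A sketch of that conversion with dummy strings (all message content hypothetical):

# history as ChatInterface supplies it: [[user, assistant], ...]
history = [["what is DfAM?", "Design for additive manufacturing is ..."]]

# Keep only the latest exchange, as predict does
last_exchange = history[-1] if history else []

# Even indices are user turns, odd indices are assistant turns
messages = [
    {"role": "assistant" if idx & 1 else "user", "content": msg}
    for idx, msg in enumerate(last_exchange)
] + [{"role": "user", "content": "a new user message"}]

print(messages)
# [{'role': 'user', 'content': 'what is DfAM?'},
#  {'role': 'assistant', 'content': 'Design for additive manufacturing is ...'},
#  {'role': 'user', 'content': 'a new user message'}]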
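More broadly, the preprocess/postprocess pair implements a simple retrieval-augmented generation pattern: retrieved abstracts are prepended to the user's message as context for the LLM, while the formatted reference list bypasses the model entirely and is appended to its reply. A stub-based check of that plumbing, with fake_search standing in for the real search function (all names and strings hypothetical):

def fake_search(message: str, k: int) -> tuple[str, str]:
    # Returns (context block for the LLM, reference list for the user)
    return "CONTEXT BLOCK\n\n", "\n\n## References\n\n1. [Example](https://example.com)\n"

def preprocess(message: str) -> tuple[str, str]:
    block, refs = fake_search(message, 5)
    return block + message, refs

def postprocess(response: str, bypass: str) -> str:
    return response + bypass

prompt, bypass = preprocess("design for additive manufacturing")
print(postprocess("A model reply would go here.", bypass))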