ccm committed on
Commit 6f5ab24 · verified · 1 Parent(s): b927386

Update app.py

Files changed (1)
  1. app.py +59 -102
app.py CHANGED
@@ -1,75 +1,50 @@
-import threading # to allow streaming response
+import gradio # Interface handling
+import spaces # For GPU
+import transformers # LLM Loading
+import langchain_community.vectorstores # Vectorstore for publications
+import langchain_huggingface # Embeddings
 
-import gradio # for the interface
-import spaces # for GPU
-import transformers # to load an LLM
-import langchain_community.vectorstores # to load the publication vectorstore
-import langchain_huggingface # for embeddings
-
-# The greeting message
+# Greeting message
 GREETING = (
-    "Howdy! "
-    "I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about additive manufacturing research. "
-    "I still make some mistakes though. "
-    "What can I tell you about today?"
+    "Howdy! I'm an AI agent that uses "
+    "[retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
+    "to answer questions about additive manufacturing research. "
+    "I'm still improving, so bear with me if I make any mistakes. "
+    "What can I help you with today?"
 )
 
-# The embedding model name
+# Constants
 EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
-
-# The LLM model name
 LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
-
-# The number of publications to retrieve
 PUBLICATIONS_TO_RETRIEVE = 10
 
 
-def embedding(
-    model_name: str = "all-MiniLM-L12-v2",
-    device: str = "cuda",
-    normalize_embeddings: bool = False,
-) -> langchain_huggingface.HuggingFaceEmbeddings:
-    """
-    Get the embedding function
-    :param model_name: The model name
-    :type model_name: str
-    :param device: The device to use
-    :type device: str
-    :param normalize_embeddings: Whether to normalize embeddings
-    :type normalize_embeddings: bool
-
-    :return: The embedding function
-    :rtype: langchain_huggingface.HuggingFaceEmbeddings
-    """
+def embedding(device: str = "cuda", normalize_embeddings: bool = False) -> langchain_huggingface.HuggingFaceEmbeddings:
+    """Loads embedding model with specified device and normalization."""
     return langchain_huggingface.HuggingFaceEmbeddings(
-        model_name=model_name,
+        model_name=EMBEDDING_MODEL_NAME,
         model_kwargs={"device": device},
         encode_kwargs={"normalize_embeddings": normalize_embeddings},
     )
 
 
 def load_publication_vectorstore() -> langchain_community.vectorstores.FAISS:
-    """
-    Load the publication vectorstore
-    :return: The publication vectorstore
-    :rtype: langchain_community.vectorstores.FAISS
-    """
-    return langchain_community.vectorstores.FAISS.load_local(
-        folder_path="publication_vectorstore",
-        embeddings=embedding(),
-        allow_dangerous_deserialization=True,
-    )
-
-
+    """Load the publication vectorstore safely."""
+    try:
+        return langchain_community.vectorstores.FAISS.load_local(
+            folder_path="publication_vectorstore",
+            embeddings=embedding(),
+            allow_dangerous_deserialization=True,
+        )
+    except Exception as e:
+        print(f"Error loading vectorstore: {e}")
+        return None
+
+
+# Load vectorstore and models
 publication_vectorstore = load_publication_vectorstore()
-
-# Create an LLM pipeline that we can send queries to
-tokenizer = transformers.AutoTokenizer.from_pretrained(
-    LLM_MODEL_NAME, trust_remote_code=True
-)
-streamer = transformers.TextIteratorStreamer(
-    tokenizer, skip_prompt=True, skip_special_tokens=True
-)
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
+streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
     LLM_MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
 )
@@ -77,78 +52,60 @@ chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
 
 def preprocess(query: str, k: int) -> str:
     """
-    Searches the dataset for the top k most relevant papers to the query and returns a prompt and references
-    Args:
-        query (str): The user's query
-        k (int): The number of results to return
-    Returns:
-        str: The prompt to be used for the AI
+    Generates a prompt based on the top k documents matching the query.
     """
     documents = publication_vectorstore.search(query, k=k, search_type="similarity")
-
-    prompt = (
-        "You are an AI assistant who delights in helping people learn about research. "
-        "Do your best to answer the following question about additive manufacturing research. "
-        "Do not refuse to answer or mention any issues with the research excerpts. "
-        "Your main task is to use the RESEARCH_EXCERPTS to provide a concise ANSWER to the USER_QUERY. "
-        "DO NOT list references at the end of the answer.\n\n"
-        "===== RESEARCH_EXCERPTS =====:\n{{EXCERPTS_GO_HERE}}\n\n"
-        "===== USER_QUERY =====:\n{{QUERY_GOES_HERE}}\n\n"
+    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]
+
+    # Prompt template
+    prompt_template = (
+        "You are an AI assistant who enjoys helping users learn about research. "
+        "Answer the following question on additive manufacturing research using the RESEARCH_EXCERPTS. "
+        "Provide a concise ANSWER based on these excerpts. Avoid listing references.\n\n"
+        "===== RESEARCH_EXCERPTS =====:\n{research_excerpts}\n\n"
+        "===== USER_QUERY =====:\n{query}\n\n"
        "===== ANSWER =====:\n"
     )
 
-    research_excerpts = [
-        '"... ' + document.page_content + '..."' for document in documents
-    ]
-
-    prompt = prompt.replace("{{EXCERPTS_GO_HERE}}", "\n\n".join(research_excerpts))
-    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
-
-    print(prompt)
+    prompt = prompt_template.format(
+        research_excerpts="\n\n".join(research_excerpts), query=query
+    )
 
+    print(prompt) # Useful for debugging prompt content
     return prompt
 
 
 @spaces.GPU
 def reply(message: str, history: list[str]) -> str:
     """
-    This function is responsible for crafting a response
-    Args:
-        message (str): The user's message
-        history (list[str]): The conversation history
-    Returns:
-        str: The AI's response
+    Generates a response to the user’s message.
     """
-
-    # Apply preprocessing
+    # Preprocess message
     message = preprocess(message, PUBLICATIONS_TO_RETRIEVE)
-
-    # This is some handling that is applied to the history variable to put it in a good format
-    history_transformer_format = [
+    history_formatted = [
         {"role": role, "content": message_pair[idx]}
         for message_pair in history
         for idx, role in enumerate(["user", "assistant"])
         if message_pair[idx] is not None
     ] + [{"role": "user", "content": message}]
 
-    # Stream a response from pipe
+    # Tokenize and prepare model input
     text = tokenizer.apply_chat_template(
-        history_transformer_format, tokenize=False, add_generation_prompt=True
+        history_formatted, tokenize=False, add_generation_prompt=True
     )
     model_inputs = tokenizer([text], return_tensors="pt").to("cuda")
 
-    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
-    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_message = ""
-    for new_token in streamer:
-        if new_token != "<":
-            partial_message += new_token
-            yield partial_message
+    # Generate response directly
+    output_tokens = chatmodel.generate(
+        **model_inputs, max_new_tokens=512
+    )
+
+    # Decode the output tokens
+    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+    return response
 
 
-# Example queries
+# Example Queries for Interface
 EXAMPLE_QUERIES = [
     "What is multi-material 3D printing?",
     "How is additive manufacturing being applied in aerospace?",
@@ -164,7 +121,7 @@ EXAMPLE_QUERIES = [
     "What are the best practices for managing post-processing in additive manufacturing?",
 ]
 
-# Create and run the gradio interface
+# Run the Gradio Interface
 gradio.ChatInterface(
     reply,
     examples=EXAMPLE_QUERIES,