ccm committed on
Commit 206bc94 · verified
1 Parent(s): 9cdb1a5

Update app.py

Files changed (1)
  1. app.py +92 -52
app.py CHANGED
@@ -1,71 +1,104 @@
+"""
+This script sets up a Gradio interface for querying an AI assistant about additive manufacturing research.
+It uses a vectorstore to retrieve relevant research excerpts and a language model to generate responses.
+
+Modules:
+- gradio: Interface handling
+- spaces: For GPU
+- transformers: LLM Loading
+- langchain_community.vectorstores: Vectorstore for publications
+- langchain_huggingface: Embeddings
+
+Constants:
+- PUBLICATIONS_TO_RETRIEVE: The number of publications to retrieve for the prompt
+- RAG_TEMPLATE: The template for the RAG prompt
+
+Functions:
+- preprocess(query: str) -> str: Generates a prompt based on the top k documents matching the query.
+- reply(message: str, history: list[str]) -> str: Generates a response to the user’s message.
+
+Example Queries:
+- "What is multi-material 3D printing?"
+- "How is additive manufacturing being applied in aerospace?"
+- "Tell me about innovations in metal 3D printing techniques."
+- "What are some sustainable materials for 3D printing?"
+- "What are the biggest challenges with support structures in additive manufacturing?"
+- "How is 3D printing impacting the medical field?"
+- "What are some common applications of additive manufacturing in industry?"
+- "What are the benefits and limitations of using polymers in 3D printing?"
+- "Tell me about the environmental impacts of additive manufacturing."
+- "What are the primary limitations of current 3D printing technologies?"
+- "How are researchers improving the speed of 3D printing processes?"
+- "What are the best practices for managing post-processing in additive manufacturing?"
+"""
+
 import gradio # Interface handling
 import spaces # For GPU
 import transformers # LLM Loading
 import langchain_community.vectorstores # Vectorstore for publications
 import langchain_huggingface # Embeddings
 
+# The number of publications to retrieve for the prompt
+PUBLICATIONS_TO_RETRIEVE = 5
 
-# Greeting message
-GREETING = (
-    "Howdy! I'm an AI agent that uses "
-    "[retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
-    "to answer questions about additive manufacturing research. "
-    "I'm still improving, so bear with me if I make any mistakes. "
-    "What can I help you with today?"
-)
-
-# Constants
-EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
-LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
-PUBLICATIONS_TO_RETRIEVE = 10
+# The template for the RAG prompt
+RAG_TEMPLATE = """You are an AI assistant who enjoys helping users learn about research.
+Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS.
+Provide a concise ANSWER based on these excerpts. Avoid listing references.
 
+===== RESEARCH_EXCERPTS =====
+{research_excerpts}
 
-def embedding(
-    device: str = "cuda", normalize_embeddings: bool = False
-) -> langchain_huggingface.HuggingFaceEmbeddings:
-    """Loads embedding model with specified device and normalization."""
-    return langchain_huggingface.HuggingFaceEmbeddings(
-        model_name=EMBEDDING_MODEL_NAME,
-        model_kwargs={"device": device},
-        encode_kwargs={"normalize_embeddings": normalize_embeddings},
-    )
-
+===== USER_QUERY =====
+{query}
 
-def load_publication_vectorstore() -> langchain_community.vectorstores.FAISS:
-    """Load the publication vectorstore safely."""
-    return langchain_community.vectorstores.FAISS.load_local(
-        folder_path="publication_vectorstore",
-        embeddings=embedding(),
-        allow_dangerous_deserialization=True,
-    )
+===== ANSWER =====
+"""
 
+# Load vectorstore of SFF publications
+publication_vectorstore = langchain_community.vectorstores.FAISS.load_local(
+    folder_path="publication_vectorstore",
+    embeddings=langchain_huggingface.HuggingFaceEmbeddings(
+        model_name="all-MiniLM-L12-v2",
+        model_kwargs={"device": "cuda"},
+        encode_kwargs={"normalize_embeddings": False},
+    ),
+    allow_dangerous_deserialization=True,
+)
 
-# Load vectorstore and models
-publication_vectorstore = load_publication_vectorstore()
+# Create the callable LLM
+llm = transformers.pipeline(
+    "text-generation", model="Qwen/Qwen2.5-7B-Instruct", device="cuda"
+)
 
 
-def preprocess(query: str, k: int) -> str:
+def preprocess(query: str) -> str:
     """
     Generates a prompt based on the top k documents matching the query.
+
+    Args:
+        query (str): The user's query.
+
+    Returns:
+        str: The formatted prompt containing research excerpts and the user's query.
     """
-    documents = publication_vectorstore.search(query, k=k, search_type="similarity")
-    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]
 
-    # Prompt template
-    prompt_template = (
-        "You are an AI assistant who enjoys helping users learn about research. "
-        "Answer the following question on additive manufacturing research using the RESEARCH_EXCERPTS. "
-        "Provide a concise ANSWER based on these excerpts. Avoid listing references.\n\n"
-        "===== RESEARCH_EXCERPTS =====\n{research_excerpts}\n\n"
-        "===== USER_QUERY =====\n{query}\n\n"
-        "===== ANSWER =====\n"
+    # Search for the top k documents matching the query
+    documents = publication_vectorstore.search(
+        query, k=PUBLICATIONS_TO_RETRIEVE, search_type="similarity"
     )
 
-    prompt = prompt_template.format(
+    # Extract the page content from the documents
+    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]
+
+    # Format the prompt with the research excerpts and the user's query
+    prompt = RAG_TEMPLATE.format(
         research_excerpts="\n\n".join(research_excerpts), query=query
     )
 
-    print(prompt) # Useful for debugging prompt content
+    # Print the prompt for debugging purposes
+    print(prompt)
+
     return prompt
 
 
@@ -73,15 +106,22 @@ def preprocess(query: str, k: int) -> str:
 def reply(message: str, history: list[str]) -> str:
     """
     Generates a response to the user’s message.
-    """
-    # Preprocess message
 
-    pipe = transformers.pipeline(
-        "text-generation", model="Qwen/Qwen2.5-7B-Instruct", device="cuda"
-    )
+    Args:
+        message (str): The user's message or query.
+        history (list[str]): The conversation history.
+
+    Returns:
+        str: The generated response from the language model.
+    """
 
-    message = preprocess(message, PUBLICATIONS_TO_RETRIEVE)
-    return pipe(message, max_new_tokens=512, return_full_text=False)[0]["generated_text"]
+    return llm(
+        preprocess(message),
+        max_new_tokens=512,
+        return_full_text=False,
+    )[
+        0
+    ]["generated_text"]
 
 
 # Example Queries for Interface
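
Note: the diff stops at the "# Example Queries for Interface" comment, so the Gradio wiring itself is not part of this commit. Below is a minimal sketch of how the refactored reply function could be exposed; the EXAMPLE_QUERIES list, the demo variable, and the title string are illustrative assumptions, not lines from app.py.

# Sketch only -- assumes example queries and interface setup follow the
# "# Example Queries for Interface" comment in the unchanged part of app.py.
EXAMPLE_QUERIES = [
    "What is multi-material 3D printing?",
    "How is additive manufacturing being applied in aerospace?",
    "What are some sustainable materials for 3D printing?",
]

# In a ZeroGPU Space, reply is typically decorated with @spaces.GPU so that
# generation runs on GPU; that decorator is not visible in this diff.

# ChatInterface passes (message, history) to its fn, matching reply's signature.
demo = gradio.ChatInterface(
    fn=reply,
    examples=EXAMPLE_QUERIES,
    title="Additive Manufacturing Research Assistant",  # assumed title
)

demo.launch()  # start the Gradio app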