# Spaces: Sleeping  (appears to be a status banner captured from the hosting page, not part of the script)
"""
This script sets up a Gradio interface for querying an AI assistant about additive manufacturing research.
It uses a vectorstore to retrieve relevant research excerpts and a language model to generate responses.

Modules:
    - gradio: Interface handling
    - spaces: For GPU
    - transformers: LLM loading
    - langchain_community.vectorstores: Vectorstore for publications
    - langchain_huggingface: Embeddings

Constants:
    - PUBLICATIONS_TO_RETRIEVE: The number of publications to retrieve for the prompt
    - RAG_TEMPLATE: The template for the RAG prompt

Functions:
    - preprocess(query: str) -> str: Generates a prompt based on the top k documents matching the query.
    - reply(message: str, history: list[str]) -> str: Generates a response to the user's message.

Example Queries:
    - "What is multi-material 3D printing?"
    - "How is additive manufacturing being applied in aerospace?"
    - "Tell me about innovations in metal 3D printing techniques."
    - "What are some sustainable materials for 3D printing?"
    - "What are the biggest challenges with support structures in additive manufacturing?"
    - "How is 3D printing impacting the medical field?"
    - "What are some common applications of additive manufacturing in industry?"
    - "What are the benefits and limitations of using polymers in 3D printing?"
    - "Tell me about the environmental impacts of additive manufacturing."
    - "What are the primary limitations of current 3D printing technologies?"
    - "How are researchers improving the speed of 3D printing processes?"
    - "What are the best practices for managing post-processing in additive manufacturing?"
"""
import os
import subprocess
import sys

# NOTE: third-party imports keep their original relative order, with `spaces`
# ahead of the model libraries; do not re-sort without checking the Spaces docs.
import gradio  # Interface handling
import spaces  # For GPU
import langchain_community.vectorstores  # Vectorstore for publications
import langchain_huggingface  # Embeddings
import transformers  # LLM loading
# Install flash-attn at startup; the LLM pipeline below requests
# attn_implementation="flash_attention_2".
#
# FLASH_ATTENTION_SKIP_CUDA_BUILD is layered on top of the inherited
# environment. Passing a bare dict as `env` replaces the environment wholesale,
# which drops PATH, HOME and any virtualenv variables that pip relies on.
# (The flash-attn installer is expected to read this variable and skip
# compiling its CUDA extension -- confirm against the flash-attn docs.)
#
# `sys.executable -m pip` with an argument list targets the interpreter that
# is running this script and avoids going through a shell.
#
# Best effort, as before: the exit status is deliberately not checked, so a
# failed install does not stop the app from starting.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# The number of publications to retrieve for the prompt | |
PUBLICATIONS_TO_RETRIEVE = 5 | |
# The template for the RAG prompt | |
RAG_TEMPLATE = """You are an AI assistant who enjoys helping users learn about research. | |
Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS. | |
Provide a concise ANSWER based on these excerpts. Avoid listing references. | |
===== RESEARCH_EXCERPTS ===== | |
{research_excerpts} | |
===== USER_QUERY ===== | |
{query} | |
===== ANSWER ===== | |
""" | |
# Embedding model handed to the vectorstore loader: all-MiniLM-L12-v2 on the
# GPU, with embedding normalization switched off.
_publication_embeddings = langchain_huggingface.HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": False},
)

# Load the FAISS vectorstore of SFF publications from the local
# "publication_vectorstore" folder.
# NOTE(review): allow_dangerous_deserialization=True opts in to pickle-based
# loading; keep it only while that folder is trusted content shipped with the app.
publication_vectorstore = langchain_community.vectorstores.FAISS.load_local(
    folder_path="publication_vectorstore",
    embeddings=_publication_embeddings,
    allow_dangerous_deserialization=True,
)
# Options for the text-generation pipeline, kept together so the model id,
# device and attention backend can be read at a glance.
# NOTE(review): `attn_implementation` is passed as a direct pipeline keyword,
# exactly as in the original; confirm the installed transformers version routes
# it to model loading (the alternative is `model_kwargs`).
_LLM_OPTIONS = {
    "task": "text-generation",
    "model": "Qwen/Qwen2.5-7B-Instruct-AWQ",
    "device": "cuda",
    "attn_implementation": "flash_attention_2",
}

# The callable LLM used by `reply`
llm = transformers.pipeline(**_LLM_OPTIONS)
def preprocess(query: str) -> str:
    """
    Build the RAG prompt for a query from its best-matching publications.

    Args:
        query (str): The user's query.

    Returns:
        str: RAG_TEMPLATE filled in with the quoted research excerpts and the query.
    """
    # Similarity search over the publication vectorstore, limited to the top k hits
    matches = publication_vectorstore.search(
        query, search_type="similarity", k=PUBLICATIONS_TO_RETRIEVE
    )

    # Wrap each excerpt in quotes and ellipses, then separate excerpts with a blank line
    quoted_excerpts = ['"... {}..."'.format(match.page_content) for match in matches]
    excerpt_block = "\n\n".join(quoted_excerpts)

    prompt = RAG_TEMPLATE.format(query=query, research_excerpts=excerpt_block)

    # Echo the finished prompt to stdout for debugging
    print(prompt)
    return prompt
def reply(message: str, history: list[str]) -> str:
    """
    Answer the user's message with the language model.

    The message is turned into a RAG prompt by `preprocess`, and the text of the
    first generation returned by `llm` is handed back.

    Args:
        message (str): The user's message or query.
        history (list[str]): The conversation history. Accepted for the chat
            callback signature; it is not used when building the prompt.

    Returns:
        str: The generated response from the language model.
    """
    # NOTE(review): `spaces` is imported "for GPU", but no `@spaces.GPU`
    # decorator is visible on this function -- confirm whether the deployment
    # hardware requires one.
    prompt = preprocess(message)

    # Generate up to 512 new tokens; return_full_text=False leaves the prompt
    # out of the returned text.
    generations = llm(prompt, max_new_tokens=512, return_full_text=False)

    first_generation = generations[0]
    return first_generation["generated_text"]
# Sample questions offered as examples in the chat interface
EXAMPLE_QUERIES = [
    "What is multi-material 3D printing?",
    "How is additive manufacturing being applied in aerospace?",
    "Tell me about innovations in metal 3D printing techniques.",
    "What are some sustainable materials for 3D printing?",
    "What are the biggest challenges with support structures in additive manufacturing?",
    "How is 3D printing impacting the medical field?",
    "What are some common applications of additive manufacturing in industry?",
    "What are the benefits and limitations of using polymers in 3D printing?",
    "Tell me about the environmental impacts of additive manufacturing.",
    "What are the primary limitations of current 3D printing technologies?",
    "How are researchers improving the speed of 3D printing processes?",
    "What are the best practices for managing post-processing in additive manufacturing?",
]

# Chat window with the label, share button, copy button and full-width
# bubbles all switched off
_chat_window = gradio.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
    bubble_full_width=False,
)

# Wire `reply` into a chat interface that lists the example queries
# (example caching disabled)
_interface = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    cache_examples=False,
    chatbot=_chat_window,
)

# Start the Gradio app in debug mode
_interface.launch(debug=True)