# Spaces: Sleeping
import collections.abc  # for generator type annotations
import threading  # to allow streaming response
import time  # to pace the delivery of the message

import gradio  # for the interface
import spaces  # for GPU
import transformers  # to load an LLM
import langchain_community.vectorstores  # to load the publication vectorstore
import langchain_huggingface  # for embeddings
# Greeting shown to the user (markdown); the sentences are joined with single spaces
GREETING = " ".join(
    [
        "Howdy!",
        "I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about additive manufacturing research.",
        "I still make some mistakes though.",
        "What can I tell you about today?",
    ]
)

# Name of the sentence-embedding model
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"

# Name of the chat LLM
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# How many publications to retrieve per query
PUBLICATIONS_TO_RETRIEVE = 5
def embedding(
    model_name: str = EMBEDDING_MODEL_NAME,
    device: str = "cuda",
    normalize_embeddings: bool = False,
) -> langchain_huggingface.HuggingFaceEmbeddings:
    """
    Get the embedding function

    The default model name comes from EMBEDDING_MODEL_NAME so the constant is
    the single place where the embedding model is chosen.

    :param model_name: The model name
    :type model_name: str
    :param device: The device to use
    :type device: str
    :param normalize_embeddings: Whether to normalize embeddings
    :type normalize_embeddings: bool
    :return: The embedding function
    :rtype: langchain_huggingface.HuggingFaceEmbeddings
    """
    return langchain_huggingface.HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )
def load_publication_vectorstore() -> langchain_community.vectorstores.FAISS:
    """
    Open the FAISS index saved in the "publication_vectorstore" folder

    :return: The publication vectorstore
    :rtype: langchain_community.vectorstores.FAISS
    """
    embedding_function = embedding()
    faiss_store = langchain_community.vectorstores.FAISS
    # NOTE(review): this opts in to the loader's "dangerous" deserialization;
    # presumably acceptable because the folder ships with the app - confirm
    # the index is never taken from an untrusted source.
    return faiss_store.load_local(
        folder_path="publication_vectorstore",
        embeddings=embedding_function,
        allow_dangerous_deserialization=True,
    )
# Vectorstore that retrieval queries are run against
publication_vectorstore = load_publication_vectorstore()

# Tokenizer for the chat LLM
tokenizer = transformers.AutoTokenizer.from_pretrained(
    LLM_MODEL_NAME,
    trust_remote_code=True,
)

# Module-level streamer: skips the echoed prompt and special tokens
streamer = transformers.TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# The chat LLM itself; device placement and dtype are left to the library
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant papers to the query and builds the LLM prompt

    Args:
        query (str): The user's query
        k (int): The number of results to retrieve

    Returns:
        tuple[str, str]: A tuple containing the prompt and the references
            (the references string is currently always empty)
    """
    # Honor the caller-supplied k instead of silently using the module default
    documents = publication_vectorstore.search(
        query, k=k, search_type="similarity"
    )

    research_excerpts = "\n\n".join(
        '"... ' + document.page_content + '..."' for document in documents
    )

    # Build the prompt in a single pass. Chained str.replace calls would also
    # rewrite placeholder text that happens to appear inside a retrieved excerpt.
    prompt = (
        "You are an AI assistant who delights in helping people learn about research. "
        "Do your best to answer the following question about additive manufacturing research. "
        "Do not refuse to answer or mention any issues with the research excerpts. "
        "Your main task is to use the RESEARCH_EXCERPTS to provide a concise ANSWER to the USER_QUERY. "
        "DO NOT list references at the end of the answer.\n\n"
        f"===== RESEARCH_EXCERPTS =====:\n{research_excerpts}\n\n"
        f"===== USER_QUERY =====:\n{query}\n\n"
        "===== ANSWER =====:\n"
    )

    # Echo the final prompt to stdout
    print(prompt)
    return prompt, ""
def reply(message: str, history: list[list]) -> collections.abc.Iterator[str]:
    """
    This function is responsible for crafting a response and streaming it back

    Args:
        message (str): The user's message
        history (list[list]): The conversation history as [user, assistant]
            pairs; either entry of a pair may be None

    Yields:
        str: The response accumulated so far, re-emitted after each new chunk
    """
    # Wrap the query in the retrieval-augmented prompt; the references slot is unused
    prompt, _ = preprocess(message, PUBLICATIONS_TO_RETRIEVE)

    # Flatten the [user, assistant] pairs into role-tagged chat messages
    conversation = [
        {"role": role, "content": content}
        for message_pair in history
        for role, content in zip(("user", "assistant"), message_pair)
        if content is not None
    ]
    conversation.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # Follow the model's own placement rather than assuming "cuda:0"
    model_inputs = tokenizer([text], return_tensors="pt").to(chatmodel.device)

    # Use a streamer per request: with one shared streamer, chunks left over
    # from an abandoned reply would surface at the start of the next one.
    response_streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs, streamer=response_streamer, max_new_tokens=512
    )

    # Generate in a background thread so chunks can be yielded as they arrive
    worker = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_token in response_streamer:
        # Chunks that are exactly "<" are dropped, as in the original filter
        if new_token != "<":
            partial_message += new_token
            yield partial_message

    # The streamer is exhausted, so generation is done; reap the thread
    worker.join()
# Sample questions offered to the user as one-click examples in the chat UI
EXAMPLE_QUERIES = [
    "What is multi-material 3D printing?",
    "How is additive manufacturing being applied in aerospace?",
    "Tell me about innovations in metal 3D printing techniques.",
    "What are some sustainable materials for 3D printing?",
    "What are the biggest challenges with support structures in additive manufacturing?",
    "How is 3D printing impacting the medical field?",
    "What are some common applications of additive manufacturing in industry?",
    "What are the benefits and limitations of using polymers in 3D printing?",
    "Are there recent breakthroughs in enhancing precision for additive manufacturing?",
    "Tell me about the environmental impacts of additive manufacturing.",
    "What are the primary limitations of current 3D printing technologies?",
    "What future trends are expected in the field of additive manufacturing?",
    "How are researchers improving the speed of 3D printing processes?",
    "What are the best practices for managing post-processing in additive manufacturing?",
]
# Chat window used by the interface
chat_window = gradio.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
    # value=[[None, GREETING]],
    height="60vh",
    bubble_full_width=False,
)

# Wire the reply generator and the example queries into a chat interface
chat_interface = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=chat_window,
)

# Start serving with debug output enabled
chat_interface.launch(debug=True)