import gradio # Interface handling import spaces # For GPU import langchain_community.vectorstores # Vectorstore for publications import langchain_huggingface # Embeddings import transformers # The number of publications to retrieve for the prompt PUBLICATIONS_TO_RETRIEVE = 5 # The template for the RAG prompt RAG_TEMPLATE = """You are an AI assistant who enjoys helping users learn about research. Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS. Provide a concise ANSWER based on these excerpts. Avoid listing references. ===== RESEARCH_EXCERPTS ===== {research_excerpts} ===== USER_QUERY ===== {query} ===== ANSWER ===== """ # Load vectorstore of SFF publications publication_vectorstore = langchain_community.vectorstores.FAISS.load_local( folder_path="publication_vectorstore", embeddings=langchain_huggingface.HuggingFaceEmbeddings( model_name="all-MiniLM-L12-v2", model_kwargs={"device": "cuda"}, encode_kwargs={"normalize_embeddings": False}, ), allow_dangerous_deserialization=True, ) # # # Create the callable LLM # llm = transformers.pipeline( # task="text-generation", # model="Qwen/Qwen2.5-7B-Instruct-AWQ", # device="cuda", # ) def preprocess(query: str) -> str: """ Generates a prompt based on the top k documents matching the query. Args: query (str): The user's query. Returns: str: The formatted prompt containing research excerpts and the user's query. """ # Search for the top k documents matching the query documents = publication_vectorstore.search( query, k=PUBLICATIONS_TO_RETRIEVE, search_type="similarity" ) # Extract the page content from the documents research_excerpts = [f'"... {doc.page_content}..."' for doc in documents] # Format the prompt with the research excerpts and the user's query prompt = RAG_TEMPLATE.format( research_excerpts="\n\n".join(research_excerpts), query=query ) return prompt import threading @spaces.GPU def reply(message: str, history: list[str]) -> str: """ Generates a response to the user’s message. Args: message (str): The user's message or query. history (list[str]): The conversation history. Returns: str: The generated response from the language model. """ tok = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-AWQ") model = transformers.AutoModelForCausalLM.from_pretrained( "Qwen/Qwen2.5-7B-Instruct-AWQ" ) inputs = tok([preprocess(message)], return_tensors="pt") streamer = transformers.TextIteratorStreamer(tok) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512) thread = threading.Thread(target=model.generate, kwargs=generation_kwargs) thread.start() generated_text = "" for new_text in streamer: generated_text += new_text yield generated_text # yield llm( # preprocess(message), # max_new_tokens=512, # return_full_text=False, # streamer=transformers.TextIteratorStreamer( # transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-AWQ") # ), # )[0]["generated_text"] # Example Queries for Interface EXAMPLE_QUERIES = [ {"text": "What is multi-material 3D printing?"}, {"text": "How is additive manufacturing being applied in aerospace?"}, {"text": "Tell me about innovations in metal 3D printing techniques."}, {"text": "What are some sustainable materials for 3D printing?"}, { "text": "What are the biggest challenges with support structures in additive manufacturing?" }, {"text": "How is 3D printing impacting the medical field?"}, { "text": "What are some common applications of additive manufacturing in industry?" }, {"text": "What are the benefits and limitations of using polymers in 3D printing?"}, {"text": "Tell me about the environmental impacts of additive manufacturing."}, {"text": "What are the primary limitations of current 3D printing technologies?"}, {"text": "How are researchers improving the speed of 3D printing processes?"}, { "text": "What are the best practices for managing post-processing in additive manufacturing?" }, ] # Run the Gradio Interface gradio.ChatInterface( reply, examples=EXAMPLE_QUERIES, cache_examples=False, chatbot=gradio.Chatbot( show_label=False, show_share_button=False, show_copy_button=False, bubble_full_width=False, ), ).launch(debug=True)