import os
import time
from typing import Dict, List, Optional, TypeAlias

import gradio as gr
import torch
import weave
from transformers import pipeline

from papersai.utils import load_paper_as_context

os.environ["TOKENIZERS_PARALLELISM"] = "false"

HistoryType: TypeAlias = List[Dict[str, str]]

# Initialize the LLM and Weave client
client = weave.init("papersai")
checkpoint: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
pipe = pipeline(
    model=checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


class ChatState:
    """Utility class to store the paper context and the last Weave call."""

    def __init__(self):
        self.context = None
        self.last_response = None


def record_feedback(x: gr.LikeData) -> None:
    """
    Logs user feedback on the assistant's response in the form of
    a like/dislike reaction.

    Reference:
    * https://weave-docs.wandb.ai/guides/tracking/feedback

    Args:
        x (gr.LikeData): User feedback data

    Returns:
        None
    """
    call = state.last_response

    # Remove any existing feedback before adding new feedback
    for existing_feedback in list(call.feedback):
        call.feedback.purge(existing_feedback.id)

    if x.liked:
        call.feedback.add_reaction("👍")
    else:
        call.feedback.add_reaction("👎")


@weave.op()
def invoke(history: HistoryType) -> str:
    """
    Thin wrapper around LLM inference, traced as a Weave op.

    Args:
        history (HistoryType): Chat history

    Returns:
        str: Response from the model
    """
    input_text = pipe.tokenizer.apply_chat_template(
        history,
        tokenize=False,
    )
    response = pipe(input_text, do_sample=True, top_p=0.95, max_new_tokens=100)[0][
        "generated_text"
    ]
    # The pipeline echoes the prompt, so keep only the final assistant turn
    response = response.split("\nassistant\n")[-1]
    return response


def update_state(history: HistoryType, message: Optional[Dict[str, str]]):
    """
    Update history and app state with the latest user input.

    Args:
        history (HistoryType): Chat history
        message (Optional[Dict[str, str]]): User input message

    Returns:
        Tuple[HistoryType, gr.MultimodalTextbox]: Updated history and chat input
    """
    if message is None:
        return history, gr.MultimodalTextbox(value=None, interactive=True)

    # Initialize history if None
    if history is None:
        history = []

    # Handle file uploads without adding to visible history
    if isinstance(message, dict) and "files" in message:
        for file_path in message["files"]:
            try:
                text = load_paper_as_context(file_path=file_path)
                doc_context = [x.get_content() for x in text]
                # Truncate the context to the model's maximum context length
                state.context = " ".join(doc_context)[
                    : pipe.model.config.max_position_embeddings
                ]
                history.append(
                    {"role": "system", "content": f"Context: {state.context}\n"}
                )
            except Exception as e:
                history.append(
                    {"role": "assistant", "content": f"Error loading file: {str(e)}"}
                )

    # Handle text input
    if isinstance(message, dict) and message.get("text"):
        history.append({"role": "user", "content": message["text"]})

    return history, gr.MultimodalTextbox(value=None, interactive=True)


def bot(history: HistoryType):
    """
    Generate a response from the LLM and stream it back to the user.

    Args:
        history (HistoryType): Chat history

    Yields:
        HistoryType: Chat history with the partial assistant response
    """
    if not history:
        return history

    try:
        # Get the response along with the Weave call object, so that
        # feedback can be attached to it later in record_feedback()
        response, call = invoke.call(history)
        state.last_response = call

        # Add empty assistant message
        history.append({"role": "assistant", "content": ""})

        # Stream the response character by character
        for character in response:
            history[-1]["content"] += character
            time.sleep(0.02)
            yield history

    except Exception as e:
        history.append({"role": "assistant", "content": f"Error: {str(e)}"})
        yield history


def create_interface():
    with gr.Blocks() as demo:
        global state
        state = ChatState()
        gr.Markdown(
            """

            # papers.ai

""", ) chatbot = gr.Chatbot( show_label=False, height=600, type="messages", show_copy_all_button=True, placeholder="Upload a research paper and ask questions!!", ) chat_input = gr.MultimodalTextbox( interactive=True, file_count="single", placeholder="Upload a document or type your message...", show_label=False, ) chat_msg = chat_input.submit( fn=update_state, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input], ) bot_msg = chat_msg.then( # noqa: F841 fn=bot, inputs=[chatbot], outputs=chatbot, api_name="bot_response" ) chatbot.like( fn=record_feedback, inputs=None, outputs=None, like_user_message=True, ) return demo def main(): demo = create_interface() demo.launch(share=False) if __name__ == "__main__": main()