Github Actions commited on
Commit
bc1817b
Β·
1 Parent(s): f327b00
Files changed (1) hide show
  1. local.py +209 -0
local.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Dict, List, Optional, TypeAlias
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import weave
8
+ from transformers import pipeline
9
+
10
+ from papersai.utils import load_paper_as_context
11
+
12
+
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+ HistoryType: TypeAlias = List[Dict[str, str]]
16
+
17
+ # Initialize the LLM and Weave client
18
+ client = weave.init("papersai")
19
+ checkpoint: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
20
+ pipe = pipeline(
21
+ model=checkpoint,
22
+ torch_dtype=torch.bfloat16,
23
+ device_map="auto",
24
+ )
25
+
26
+
27
+ class ChatState:
28
+ """Utility class to store context and last response"""
29
+
30
+ def __init__(self):
31
+ self.context = None
32
+ self.last_response = None
33
+
34
+
35
+ def record_feedback(x: gr.LikeData) -> None:
36
+ """
37
+ Logs user feedback on the assistant's response in the form of a
38
+ like/dislike reaction.
39
+
40
+ Reference:
41
+ * https://weave-docs.wandb.ai/guides/tracking/feedback
42
+
43
+ Args:
44
+ x (gr.LikeData): User feedback data
45
+
46
+ Returns:
47
+ None
48
+ """
49
+ call = state.last_response
50
+
51
+ # Remove any existing feedback before adding new feedback
52
+ for existing_feedback in list(call.feedback):
53
+ call.feedback.purge(existing_feedback.id)
54
+
55
+ if x.liked:
56
+ call.feedback.add_reaction("πŸ‘")
57
+ else:
58
+ call.feedback.add_reaction("πŸ‘Ž")
59
+
60
+
61
+ @weave.op()
62
+ def invoke(history: HistoryType):
63
+ """
64
+ Simple wrapper around llm inference wrapped in a weave op
65
+
66
+ Args:
67
+ history (HistoryType): Chat history
68
+
69
+ Returns:
70
+ BaseMessage: Response from the model
71
+ """
72
+ input_text = pipe.tokenizer.apply_chat_template(
73
+ history,
74
+ tokenize=False,
75
+ )
76
+ response = pipe(input_text, do_sample=True, top_p=0.95, max_new_tokens=100)[0][
77
+ "generated_text"
78
+ ]
79
+ response = response.split("\nassistant\n")[-1]
80
+ return response
81
+
82
+
83
+ def update_state(history: HistoryType, message: Optional[Dict[str, str]]):
84
+ """
85
+ Update history and app state with the latest user input.
86
+
87
+ Args:
88
+ history (HistoryType): Chat history
89
+ message (Optional[Dict[str, str]]): User input message
90
+
91
+ Returns:
92
+ Tuple[HistoryType, gr.MultimodalTextbox]: Updated history and chat input
93
+ """
94
+ if message is None:
95
+ return history, gr.MultimodalTextbox(value=None, interactive=True)
96
+
97
+ # Initialize history if None
98
+ if history is None:
99
+ history = []
100
+
101
+ # Handle file uploads without adding to visible history
102
+ if isinstance(message, dict) and "files" in message:
103
+ for file_path in message["files"]:
104
+ try:
105
+ text = load_paper_as_context(file_path=file_path)
106
+ doc_context = [x.get_content() for x in text]
107
+ state.context = " ".join(doc_context)[
108
+ : pipe.model.config.max_position_embeddings
109
+ ]
110
+ history.append(
111
+ {"role": "system", "content": f"Context: {state.context}\n"}
112
+ )
113
+ except Exception as e:
114
+ history.append(
115
+ {"role": "assistant", "content": f"Error loading file: {str(e)}"}
116
+ )
117
+
118
+ # Handle text input
119
+ if isinstance(message, dict) and message.get("text"):
120
+ history.append({"role": "user", "content": message["text"]})
121
+
122
+ return history, gr.MultimodalTextbox(value=None, interactive=True)
123
+
124
+
125
+ def bot(history: HistoryType):
126
+ """
127
+ Generate response from the LLM and stream it back to the user.
128
+
129
+ Args:
130
+ history (HistoryType): Chat history
131
+
132
+ Yields:
133
+ response from the LLM
134
+ """
135
+ if not history:
136
+ return history
137
+
138
+ try:
139
+ # Get response from LLM
140
+ response, call = invoke.call(history)
141
+ state.last_response = call
142
+
143
+ # Add empty assistant message
144
+ history.append({"role": "assistant", "content": ""})
145
+
146
+ # Stream the response
147
+ for character in response:
148
+ history[-1]["content"] += character
149
+ time.sleep(0.02)
150
+ yield history
151
+
152
+ except Exception as e:
153
+ history.append({"role": "assistant", "content": f"Error: {str(e)}"})
154
+ yield history
155
+
156
+
157
+ def create_interface():
158
+ with gr.Blocks() as demo:
159
+ global state
160
+ state = ChatState()
161
+ gr.Markdown(
162
+ """
163
+ <a href="https://github.com/SauravMaheshkar/papersai">
164
+ <div align="center"><h1>papers.ai</h1></div>
165
+ </a>
166
+ """,
167
+ )
168
+ chatbot = gr.Chatbot(
169
+ show_label=False,
170
+ height=600,
171
+ type="messages",
172
+ show_copy_all_button=True,
173
+ placeholder="Upload a research paper and ask questions!!",
174
+ )
175
+
176
+ chat_input = gr.MultimodalTextbox(
177
+ interactive=True,
178
+ file_count="single",
179
+ placeholder="Upload a document or type your message...",
180
+ show_label=False,
181
+ )
182
+
183
+ chat_msg = chat_input.submit(
184
+ fn=update_state,
185
+ inputs=[chatbot, chat_input],
186
+ outputs=[chatbot, chat_input],
187
+ )
188
+
189
+ bot_msg = chat_msg.then( # noqa: F841
190
+ fn=bot, inputs=[chatbot], outputs=chatbot, api_name="bot_response"
191
+ )
192
+
193
+ chatbot.like(
194
+ fn=record_feedback,
195
+ inputs=None,
196
+ outputs=None,
197
+ like_user_message=True,
198
+ )
199
+
200
+ return demo
201
+
202
+
203
+ def main():
204
+ demo = create_interface()
205
+ demo.launch(share=False)
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()