Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ from transformers import (
|
|
5 |
pipeline,
|
6 |
AutoModelForSequenceClassification,
|
7 |
AutoTokenizer,
|
8 |
-
AutoModelForSeq2SeqLM,
|
9 |
AutoModelForCausalLM,
|
10 |
T5Tokenizer,
|
11 |
T5ForConditionalGeneration,
|
@@ -128,7 +127,7 @@ class Chatbot:
|
|
128 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
129 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
130 |
|
131 |
-
def generate_response(self, prompt, max_length=
|
132 |
"""
|
133 |
Generates a response to a user query using GPT-2.
|
134 |
"""
|
@@ -137,6 +136,33 @@ class Chatbot:
|
|
137 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
138 |
return response
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
# Initialize models
|
142 |
classifier = ContentClassifier()
|
@@ -259,8 +285,18 @@ async def chat(request: PromptRequest):
|
|
259 |
if not prompt:
|
260 |
raise HTTPException(status_code=400, detail="No prompt provided")
|
261 |
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
|
265 |
|
266 |
# Start the FastAPI app
|
|
|
5 |
pipeline,
|
6 |
AutoModelForSequenceClassification,
|
7 |
AutoTokenizer,
|
|
|
8 |
AutoModelForCausalLM,
|
9 |
T5Tokenizer,
|
10 |
T5ForConditionalGeneration,
|
|
|
127 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
128 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
129 |
|
130 |
+
def generate_response(self, prompt, max_length=100):
|
131 |
"""
|
132 |
Generates a response to a user query using GPT-2.
|
133 |
"""
|
|
|
136 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
137 |
return response
|
138 |
|
139 |
+
def handle_request(self, prompt):
|
140 |
+
"""
|
141 |
+
Handles user requests by determining the intent and delegating to the appropriate function.
|
142 |
+
"""
|
143 |
+
# Check if the user wants to search for something
|
144 |
+
if "search" in prompt.lower():
|
145 |
+
query = prompt.lower().replace("search", "").strip()
|
146 |
+
results = search_engine.search(query)
|
147 |
+
return {"type": "search", "results": results}
|
148 |
+
|
149 |
+
# Check if the user wants a summary
|
150 |
+
elif "summarize" in prompt.lower() or "summary" in prompt.lower():
|
151 |
+
text = prompt.lower().replace("summarize", "").replace("summary", "").strip()
|
152 |
+
summary = summarizer.summarize(text)
|
153 |
+
return {"type": "summary", "summary": summary}
|
154 |
+
|
155 |
+
# Check if the user wants to extract topics
|
156 |
+
elif "topics" in prompt.lower() or "topic" in prompt.lower():
|
157 |
+
text = prompt.lower().replace("topics", "").replace("topic", "").strip()
|
158 |
+
topics = topic_extractor.extract_topics([text])
|
159 |
+
return {"type": "topics", "topics": topics.to_dict()}
|
160 |
+
|
161 |
+
# Default to generating a conversational response
|
162 |
+
else:
|
163 |
+
response = self.generate_response(prompt)
|
164 |
+
return {"type": "chat", "response": response}
|
165 |
+
|
166 |
|
167 |
# Initialize models
|
168 |
classifier = ContentClassifier()
|
|
|
285 |
if not prompt:
|
286 |
raise HTTPException(status_code=400, detail="No prompt provided")
|
287 |
|
288 |
+
# Handle the request using the chatbot's handle_request method
|
289 |
+
result = chatbot.handle_request(prompt)
|
290 |
+
|
291 |
+
# Return the appropriate response based on the type of request
|
292 |
+
if result["type"] == "search":
|
293 |
+
return {"type": "search", "results": result["results"]}
|
294 |
+
elif result["type"] == "summary":
|
295 |
+
return {"type": "summary", "summary": result["summary"]}
|
296 |
+
elif result["type"] == "topics":
|
297 |
+
return {"type": "topics", "topics": result["topics"]}
|
298 |
+
else:
|
299 |
+
return {"type": "chat", "response": result["response"]}
|
300 |
|
301 |
|
302 |
# Start the FastAPI app
|