Fred808 committed on
Commit
be7f0e5
·
verified ·
1 Parent(s): c80a7e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -4
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import (
5
  pipeline,
6
  AutoModelForSequenceClassification,
7
  AutoTokenizer,
8
- AutoModelForSeq2SeqLM,
9
  AutoModelForCausalLM,
10
  T5Tokenizer,
11
  T5ForConditionalGeneration,
@@ -128,7 +127,7 @@ class Chatbot:
128
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
129
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
130
 
131
- def generate_response(self, prompt, max_length=50):
132
  """
133
  Generates a response to a user query using GPT-2.
134
  """
@@ -137,6 +136,33 @@ class Chatbot:
137
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
138
  return response
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  # Initialize models
142
  classifier = ContentClassifier()
@@ -259,8 +285,18 @@ async def chat(request: PromptRequest):
259
  if not prompt:
260
  raise HTTPException(status_code=400, detail="No prompt provided")
261
 
262
- response = chatbot.generate_response(prompt)
263
- return {"response": response}
 
 
 
 
 
 
 
 
 
 
264
 
265
 
266
  # Start the FastAPI app
 
5
  pipeline,
6
  AutoModelForSequenceClassification,
7
  AutoTokenizer,
 
8
  AutoModelForCausalLM,
9
  T5Tokenizer,
10
  T5ForConditionalGeneration,
 
127
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
128
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
129
 
130
+ def generate_response(self, prompt, max_length=100):
131
  """
132
  Generates a response to a user query using GPT-2.
133
  """
 
136
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
137
  return response
138
 
139
def handle_request(self, prompt):
    """
    Route a user prompt to search, summarization, topic extraction,
    or conversational generation based on keywords in the prompt.

    Keyword detection is case-insensitive substring matching (unchanged
    from the original routing, so callers see identical branch choices),
    but the text forwarded to downstream models now preserves the
    prompt's original casing and only strips whole-word routing
    keywords. The previous implementation lowercased the entire prompt
    and used str.replace, which destroyed case information and mangled
    words containing a keyword (e.g. "research" -> "re").

    Parameters
    ----------
    prompt : str
        Raw user input.

    Returns
    -------
    dict
        One of:
        {"type": "search",  "results": ...}
        {"type": "summary", "summary": ...}
        {"type": "topics",  "topics": ...}
        {"type": "chat",    "response": ...}
    """
    import re  # local import: keyword stripping only; avoids touching the file header

    def _strip_keywords(text, keywords):
        # Remove only whole-word occurrences of the routing keywords,
        # case-insensitively, keeping the rest of the prompt intact.
        pattern = r"\b(?:" + "|".join(map(re.escape, keywords)) + r")\b"
        return re.sub(pattern, "", text, flags=re.IGNORECASE).strip()

    lowered = prompt.lower()

    # Search intent — delegate to the search engine.
    if "search" in lowered:
        query = _strip_keywords(prompt, ["search"])
        results = search_engine.search(query)
        return {"type": "search", "results": results}

    # Summarization intent.
    if "summarize" in lowered or "summary" in lowered:
        text = _strip_keywords(prompt, ["summarize", "summary"])
        summary = summarizer.summarize(text)
        return {"type": "summary", "summary": summary}

    # Topic-extraction intent. "topics" is listed before "topic" so the
    # plural form is removed whole, not left as a trailing "s".
    if "topics" in lowered or "topic" in lowered:
        text = _strip_keywords(prompt, ["topics", "topic"])
        topics = topic_extractor.extract_topics([text])
        return {"type": "topics", "topics": topics.to_dict()}

    # Default: conversational response via the chatbot's own generator.
    response = self.generate_response(prompt)
    return {"type": "chat", "response": response}
166
 
167
  # Initialize models
168
  classifier = ContentClassifier()
 
285
  if not prompt:
286
  raise HTTPException(status_code=400, detail="No prompt provided")
287
 
288
+ # Handle the request using the chatbot's handle_request method
289
+ result = chatbot.handle_request(prompt)
290
+
291
+ # Return the appropriate response based on the type of request
292
+ if result["type"] == "search":
293
+ return {"type": "search", "results": result["results"]}
294
+ elif result["type"] == "summary":
295
+ return {"type": "summary", "summary": result["summary"]}
296
+ elif result["type"] == "topics":
297
+ return {"type": "topics", "topics": result["topics"]}
298
+ else:
299
+ return {"type": "chat", "response": result["response"]}
300
 
301
 
302
  # Start the FastAPI app