Upload app.py
app.py ADDED
@@ -0,0 +1,105 @@
import os
import streamlit as st
from typing import List, Dict, Any
from serpapi import GoogleSearch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """
    Search Google via SerpAPI for the given query and return the organic results.
    """
    try:
        search = GoogleSearch({
            "q": query,
            "location": loc,
            "api_key": api_key
        })
        results = search.get_dict()
        return results.get("organic_results", [])
    except Exception as e:
        raise RuntimeError(f"SerpAPI search failed: {e}") from e

def convert_to_md_table(data):
    md_table = "| Title | Link |\n| :--- | :--- |\n"
    for item in data:
        title = item.get('title', '')
        link = item.get('link', '')
        md_table += f"| {title} | [Link]({link}) |\n"
    return md_table

# Load the fake-news detection model directly
tokenizer = AutoTokenizer.from_pretrained("jy46604790/Fake-News-Bert-Detect")
model = AutoModelForSequenceClassification.from_pretrained("jy46604790/Fake-News-Bert-Detect")

def call_classifier(text: str):
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    label = torch.argmax(probabilities, dim=1).item()
    score = probabilities[0][label].item()
    return {"label": label, "score": score}

# Initialize session state
if 'history' not in st.session_state:
    st.session_state.history = []

if 'user_input' not in st.session_state:
    st.session_state.user_input = ""

if 'score' not in st.session_state:
    st.session_state.score = {'label': 0, 'score': 0.0}  # same shape as call_classifier's return value

# Streamlit app layout
st.title("Chatbot News Search")

# User input
st.session_state.user_input = st.text_input("What news do you want to search?", st.session_state.user_input)

# Confidence threshold for running a news search
threshold = 0.7

# Main logic
if st.session_state.user_input:
    st.session_state.history.append(f"User: {st.session_state.user_input}")

    if st.session_state.score['score'] > threshold:
        query = st.session_state.user_input
        SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")  # Use an environment variable for the SerpAPI key
        if not SERPAPI_API_KEY:
            st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
        else:
            news_results = search_serpapi(query, "New York", SERPAPI_API_KEY)
            formatted_news = convert_to_md_table(news_results)
            st.session_state.history.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")

            user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key="confirm_results")
            if user_continue == 'E':
                st.session_state.history.append("User exited the conversation.")
            else:
                new_score = call_classifier(query)
                while new_score['score'] < st.session_state.score['score']:
                    st.session_state.history.append("Run the SerpAPI again.")
                    new_score = call_classifier(query)  # keep the full {label, score} dict
                    st.session_state.history.append(f"New score: {new_score['score']}")

                    user_continue = st.radio("Are you okay with this?", ('Y', 'E'), key="confirm_rescore")
                    if user_continue == 'E':
                        st.session_state.history.append("User exited the conversation.")
                        break

                st.session_state.score = new_score

    else:
        st.session_state.history.append(f"Chatbot: Current score: {st.session_state.score['score']}")
        st.session_state.user_input = st.text_input("Please provide more information to refine the news search:", st.session_state.user_input)
        st.session_state.score = call_classifier(st.session_state.user_input)
        st.session_state.history.append(f"New score: {st.session_state.score['score']}")

# Display chat history
for message in st.session_state.history:
    st.write(message)

st.write("If you want to finish the conversation, please enter 'exit'.")
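
A quick way to sanity-check the two helper functions outside of a Streamlit run is a small script like the sketch below. It is not part of the uploaded file: it assumes the dependencies are installed, that the model weights download on first use, and that importing app.py in bare mode is acceptable (the module-level Streamlit calls only emit warnings there). The headline and result entries are made up for illustration.

# check_helpers.py - illustrative smoke test (hypothetical file, not in this commit)
from app import call_classifier, convert_to_md_table

# Classify a made-up headline; label is the argmax class index, score its probability
print(call_classifier("City council approves new transit budget after public hearing."))

# Render a tiny hand-written result list the way the chatbot would
sample_results = [
    {"title": "Example headline", "link": "https://example.com/article"},
    {"title": "Another headline", "link": "https://example.com/other"},
]
print(convert_to_md_table(sample_results))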