kzz1027 commited on
Commit
2cbca4c
1 Parent(s): 483a6d7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from typing import List, Dict, Any
4
+ from serpapi import GoogleSearch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ import torch
7
+
8
+ def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
9
+ """
10
+ Search using SerpAPI for the given query and return the results.
11
+ """
12
+ try:
13
+ search = GoogleSearch({
14
+ "q": query,
15
+ "location": loc,
16
+ "api_key": api_key
17
+ })
18
+ results = search.get_dict()
19
+ return results.get("organic_results", [])
20
+ except Exception as e:
21
+ raise Exception(f"An error occurred: {e}")
22
+
23
+ def convert_to_md_table(data):
24
+ md_table = "| Title | Link |\n| :--- | :--- |\n"
25
+ for item in data:
26
+ title = item['title']
27
+ link = item['link']
28
+ md_table += f"| {title} | [Link]({link}) |\n"
29
+ return md_table
30
+
31
+ # Load model directly
32
+ tokenizer = AutoTokenizer.from_pretrained("jy46604790/Fake-News-Bert-Detect")
33
+ model = AutoModelForSequenceClassification.from_pretrained("jy46604790/Fake-News-Bert-Detect")
34
+
35
+ def call_classifier(text: str):
36
+ inputs = tokenizer(text, return_tensors="pt")
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ logits = outputs.logits
40
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
41
+ label = torch.argmax(probabilities, dim=1).item()
42
+ score = probabilities[0][label].item()
43
+ return {"label": label, "score": score}
44
+
45
+ # Initialize session state
46
+ if 'history' not in st.session_state:
47
+ st.session_state.history = []
48
+
49
+ if 'user_input' not in st.session_state:
50
+ st.session_state.user_input = ""
51
+
52
+ if 'score' not in st.session_state:
53
+ st.session_state.score = {'label': 'LABEL_0', 'score': 0.0}
54
+
55
+ # Streamlit app layout
56
+ st.title("Chatbot News Search")
57
+
58
+ # User input
59
+ st.session_state.user_input = st.text_input("What news do you want to search?", st.session_state.user_input)
60
+
61
+ # Threshold
62
+ threshold = 0.7
63
+
64
+ # Main logic
65
+ if st.session_state.user_input:
66
+ st.session_state.history.append(f"User: {st.session_state.user_input}")
67
+
68
+ if st.session_state.score['score'] > threshold:
69
+ query = st.session_state.user_input
70
+ SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY") # Use environment variable for SerpAPI key
71
+ if not SERPAPI_API_KEY:
72
+ st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
73
+ else:
74
+ news_results = search_serpapi(query, "New York", SERPAPI_API_KEY)
75
+ formatted_news = convert_to_md_table(news_results)
76
+ st.session_state.history.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")
77
+
78
+ user_continue = st.radio("Are you okay with this?", ('Y', 'E'))
79
+ if user_continue == 'E':
80
+ st.session_state.history.append("User exited the conversation.")
81
+ else:
82
+ new_score = call_classifier(query)
83
+ while new_score['score'] < st.session_state.score['score']:
84
+ st.session_state.history.append("Run the SerpAPI again.")
85
+ new_score = call_classifier(query)['score']
86
+ st.session_state.history.append(f"New score: {new_score}")
87
+
88
+ user_continue = st.radio("Are you okay with this?", ('Y', 'E'))
89
+ if user_continue == 'E':
90
+ st.session_state.history.append("User exited the conversation.")
91
+ break
92
+
93
+ st.session_state.score = new_score
94
+
95
+ else:
96
+ st.session_state.history.append(f"Chatbot: Current score: {st.session_state.score['score']}")
97
+ st.session_state.user_input = st.text_input("Please provide more information to refine the news search:", st.session_state.user_input)
98
+ st.session_state.score = call_classifier(st.session_state.user_input)
99
+ st.session_state.history.append(f"New score: {st.session_state.score['score']}")
100
+
101
+ # Display chat history
102
+ for message in st.session_state.history:
103
+ st.write(message)
104
+
105
+ st.write("If you want to finish the conversation, please enter 'exit'.")