import os
from typing import Any, Callable, Dict, List, Tuple

import streamlit as st
import torch
from serpapi import GoogleSearch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face model used to score the user's query.
MODEL_NAME = "jy46604790/Fake-News-Bert-Detect"
# SerpAPI ``location`` parameter used for every news search.
SEARCH_LOCATION = "New York"


def search_serpapi(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """
    Search Google through SerpAPI and return the organic results.

    Args:
        query: Search terms.
        loc: SerpAPI ``location`` parameter.
        api_key: SerpAPI API key.

    Returns:
        The ``organic_results`` list from the response, or an empty list when
        the response has no such key.

    Raises:
        RuntimeError: If the SerpAPI client raises; the original exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    params = {"q": query, "location": loc, "api_key": api_key}
    try:
        results = GoogleSearch(params).get_dict()
    except Exception as e:
        raise RuntimeError(f"An error occurred: {e}") from e
    return results.get("organic_results", [])


def _md_cell(value: Any) -> str:
    """Return *value* as text that is safe inside a Markdown table cell."""
    text = "" if value is None else str(value)
    # A raw "|" would open a new column and a newline would end the row early.
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def convert_to_md_table(data: List[Dict[str, Any]]) -> str:
    """
    Render search results as a two-column Markdown table (Title, Link).

    Missing ``title``/``link`` keys become empty cells instead of raising.
    """
    rows = ["| Title | Link |", "| :--- | :--- |"]
    for item in data:
        title = _md_cell(item.get("title", ""))
        link = _md_cell(item.get("link", ""))
        rows.append(f"| {title} | [Link]({link}) |")
    return "\n".join(rows) + "\n"


def _cache_resource(func: Callable) -> Callable:
    """Wrap *func* with ``st.cache_resource`` when this Streamlit version has it."""
    cache = getattr(st, "cache_resource", None)
    return cache(func) if cache is not None else func


@_cache_resource
def _load_classifier(model_name: str = MODEL_NAME) -> Tuple[Any, Any]:
    """Load the tokenizer and sequence-classification model for *model_name*."""
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )


# Load model directly. The loader is cached (when supported) so Streamlit
# reruns do not reload the weights on every widget interaction.
tokenizer, model = _load_classifier()


def call_classifier(text: str) -> Dict[str, Any]:
    """
    Classify *text* with the fake-news model.

    Returns:
        ``{"label": int, "score": float}`` where ``label`` is the argmax class
        index and ``score`` is the softmax probability of that class, i.e. the
        model's confidence in whichever label it picked.
    """
    # truncation=True keeps the input within the model's maximum sequence
    # length; otherwise long text makes the forward pass error out.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    label = torch.argmax(probabilities, dim=1).item()
    score = probabilities[0][label].item()
    return {"label": label, "score": score}


def _init_session_state() -> None:
    """Create every session_state key the app reads, only if it is missing."""
    defaults: Dict[str, Any] = {
        "history": [],       # chat messages displayed at the bottom of the page
        "user_input": "",    # current text of the main search box
        # Last classifier output; label is an int to match call_classifier().
        "score": {"label": 0, "score": 0.0},
        "turn_key": None,    # identifies the turn currently at the end of history
        "turn_start": 0,     # index in history where that turn begins
        "search_cache": {},  # (query, location) -> SerpAPI results
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _record_turn(turn_key: Tuple[str, str], messages: List[str]) -> None:
    """
    Write this run's messages into the chat history.

    Streamlit re-executes the whole script on every widget interaction, so the
    same turn is recomputed many times. Messages for an unchanged turn replace
    the previous copy instead of being appended again.
    """
    state = st.session_state
    if state.turn_key != turn_key:
        state.turn_key = turn_key
        state.turn_start = len(state.history)
    state.history[state.turn_start:] = messages


def _cached_search(query: str, loc: str, api_key: str) -> List[Dict[str, Any]]:
    """Call search_serpapi at most once per (query, location) in this session."""
    cache = st.session_state.search_cache
    key = (query, loc)
    if key not in cache:
        # Failures raise before storing, so a later rerun retries the call.
        cache[key] = search_serpapi(query, loc, api_key)
    return cache[key]


def _search_step(query: str, messages: List[str]) -> None:
    """Search news for *query*, show the results, and ask the user to confirm."""
    api_key = os.getenv("SERPAPI_API_KEY")  # Use environment variable for SerpAPI key
    if not api_key:
        st.error("SerpAPI API key not found. Please set the SERPAPI_API_KEY environment variable.")
        return
    try:
        news_results = _cached_search(query, SEARCH_LOCATION, api_key)
    except RuntimeError as exc:
        st.error(str(exc))
        return

    formatted_news = convert_to_md_table(news_results)
    messages.append(f"Chatbot: Here are the latest news results:\n{formatted_news}")

    user_continue = st.radio("Are you okay with this?", ('Y', 'E'))
    if user_continue == 'E':
        messages.append("User exited the conversation.")
        return

    new_score = call_classifier(query)
    # Re-classifying the same query is deterministic, so retrying in a loop
    # can never raise the score; report a drop once instead of looping.
    if new_score["score"] < st.session_state.score["score"]:
        messages.append("Run the SerpAPI again.")
        messages.append(f"New score: {new_score['score']}")
    st.session_state.score = new_score


def _refine_step(messages: List[str]) -> None:
    """Ask for more detail about the query and re-score it."""
    messages.append(f"Chatbot: Current score: {st.session_state.score['score']}")
    st.session_state.user_input = st.text_input(
        "Please provide more information to refine the news search:",
        st.session_state.user_input,
    )
    st.session_state.score = call_classifier(st.session_state.user_input)
    messages.append(f"New score: {st.session_state.score['score']}")


# Initialize session state
_init_session_state()

# Streamlit app layout
st.title("Chatbot News Search")

# User input
st.session_state.user_input = st.text_input("What news do you want to search?", st.session_state.user_input)

# Minimum classifier confidence required before running a news search.
# NOTE(review): the gate uses the score stored by an earlier run, not a score
# computed for the text just entered — confirm this is the intended flow.
threshold = 0.7

# Main logic
current_input = st.session_state.user_input
turn_messages: List[str] = []
if current_input.strip().lower() == "exit":
    # The footer tells users that entering 'exit' ends the conversation.
    stage = "exit"
    turn_messages.append("User exited the conversation.")
elif current_input:
    turn_messages.append(f"User: {current_input}")
    if st.session_state.score["score"] > threshold:
        stage = "search"
        _search_step(current_input, turn_messages)
    else:
        stage = "refine"
        _refine_step(turn_messages)
else:
    stage = "idle"
_record_turn((current_input, stage), turn_messages)

# Display chat history
for message in st.session_state.history:
    st.write(message)
st.write("If you want to finish the conversation, please enter 'exit'.")