import streamlit as st
import requests
import torch
from transformers import AutoTokenizer, AutoModel
import xml.etree.ElementTree as ET

# Load the SciBERT pre-trained tokenizer and model once, cached across
# Streamlit reruns so they are not reloaded on every interaction.
@st.cache_resource
def load_scibert(model_name="allenai/scibert_scivocab_uncased"):
    return AutoTokenizer.from_pretrained(model_name), AutoModel.from_pretrained(model_name)

tokenizer, model = load_scibert()

def calculate_similarity(claim, document):
    """Return the cosine similarity between SciBERT embeddings of two texts."""
    if not claim or not document:
        return 0.0

    # Tokenize and embed the claim on its own.
    inputs_claim = tokenizer(claim, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        claim_embeddings = model(**inputs_claim).pooler_output

    # Embed the document separately, so each text gets its own pooled vector.
    inputs_doc = tokenizer(document, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        document_embeddings = model(**inputs_doc).pooler_output

    # Compute cosine similarity between the two pooled [CLS] embeddings.
    return torch.cosine_similarity(claim_embeddings, document_embeddings).item()
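
# Illustrative sanity check (hypothetical inputs; exact scores depend on the
# model weights): a related pair such as
#   calculate_similarity("Aspirin reduces fever",
#                        "We measure the antipyretic effect of aspirin.")
# should score higher than an unrelated pair such as
#   calculate_similarity("Aspirin reduces fever",
#                        "A survey of graph neural networks.")
# Pooler-output similarities tend to cluster in a narrow high range, so
# compare scores relative to each other rather than to a fixed threshold.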

def search_arxiv(query, max_results=3):
    """Query the arXiv Atom API and return a list of result dicts, or None."""
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
        "sortBy": "relevance",
        "sortOrder": "descending",
    }

    try:
        # Let requests URL-encode the query instead of interpolating it raw.
        response = requests.get(base_url, params=params, timeout=10)
        if response.status_code != 200:
            return None

        # Parse the Atom XML response.
        root = ET.fromstring(response.content)
        ns = "{http://www.w3.org/2005/Atom}"

        search_results = []
        for entry in root.findall(f"{ns}entry"):
            result = {}

            # Extract title, abstract, and PDF link from each entry,
            # collapsing the hard-wrapped whitespace arXiv puts in titles.
            result["title"] = " ".join(entry.find(f"{ns}title").text.split())
            result["abstract"] = entry.find(f"{ns}summary").text.strip()
            pdf_link = entry.find(f"{ns}link[@title='pdf']")
            result["link"] = pdf_link.attrib["href"] if pdf_link is not None else None

            result["authors"] = [
                author.find(f"{ns}name").text
                for author in entry.findall(f"{ns}author")
            ]

            search_results.append(result)

        return search_results
    except (requests.RequestException, ET.ParseError):
        return None
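
# For reference, the Atom feed returned by the arXiv API nests each paper in
# an <entry> element, roughly:
#   <feed xmlns="http://www.w3.org/2005/Atom">
#     <entry>
#       <title>...</title>
#       <summary>...</summary>
#       <author><name>...</name></author>
#       <link title="pdf" href="..."/>
#     </entry>
#   </feed>
# which is why every find()/findall() above is qualified with the Atom
# namespace URI.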

def search_papers(user_input):
    # Use the desired search function, e.g., search_arxiv
    search_results = search_arxiv(user_input)
    return search_results
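
# search_papers() is the single seam for plugging in a different source; a
# hypothetical alternative backend would only need to return the same
# list-of-dicts shape:
#   [{"title": str, "abstract": str, "link": str or None, "authors": [str]}]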

st.title('The Substantiator')

user_input = st.text_input('Input your claim')

if st.button('Substantiate'):
    if not user_input:
        st.write("Please enter a claim first.")
    else:
        # Run the search inside the spinner so the user sees progress while
        # the request is in flight.
        with st.spinner('Searching for relevant research papers...'):
            search_results = search_papers(user_input)

        if search_results:
            for result in search_results:
                # Render the title as a clickable Markdown link to the PDF.
                if result["link"]:
                    st.markdown(f"[{result['title']}]({result['link']})")
                else:
                    st.write(result["title"])
                st.write(result["abstract"])
                st.write("Authors: ", ", ".join(result["authors"]))
                similarity = calculate_similarity(user_input, result["abstract"])
                st.write("Similarity Score: ", similarity)
                st.write("-----")
        else:
            st.write("No results found.")