# PatentMatch / app.py
# Page-header residue from the file listing this source was copied from:
# DataRaptor's picture · "Upload 2 files" · 6932c6c
import pandas as pd
import streamlit as st
from infer import USPPPMModel, USPPPMDataset
import torch
@st.cache_resource
def load_model():
    """Build the similarity model and its companion dataset helper.

    The model is created from the 'microsoft/deberta-v3-small' identifier,
    populated with the state dict stored in 'model_weights.pth' (mapped onto
    the CPU) and put into eval mode. The result is wrapped by
    ``st.cache_resource``.

    Returns:
        tuple: ``(model, dataset)`` where ``dataset`` is a ``USPPPMDataset``
        built from the model's tokenizer.
    """
    net = USPPPMModel('microsoft/deberta-v3-small')

    cpu = torch.device('cpu')
    state = torch.load('model_weights.pth', map_location=cpu)
    net.load_state_dict(state)
    net.eval()

    # NOTE(review): 133 is presumably the maximum token length - confirm
    # against USPPPMDataset in infer.py.
    dataset = USPPPMDataset(net.tokenizer, 133)
    return net, dataset
def infer(anchor, target, title):
    """Run the model on a single anchor/target pair within a title context.

    Args:
        anchor: The anchor phrase.
        target: The target phrase compared against the anchor.
        title: The context title the pair is evaluated in.

    Returns:
        The element at position ``[0][0]`` of the model output converted to
        a NumPy array.
    """
    net, dataset = load_model()

    # The dataset helper is indexed with a record dict; 'label' is filled
    # with a 0 placeholder.
    record = dict(anchor=anchor, target=target, title=title, label=0)
    features = dataset[record][0]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = net(features)

    scores = output.cpu().numpy()
    first_row = scores[0]
    return first_row[0]
@st.cache_data
def get_context():
    """Return the distinct context titles from the training fold CSV.

    Reads './fold-0-train.csv' and collects the unique, non-missing values of
    its 'title' column. The result is sorted so the option order is the same
    on every run; iterating a ``set`` of strings gives an order that can
    change between interpreter runs. The result is wrapped by
    ``st.cache_data``.

    Returns:
        list: The unique titles in ascending order.
    """
    df = pd.read_csv('./fold-0-train.csv')
    # Missing titles are dropped: they are not selectable contexts and would
    # break sorting alongside strings.
    titles = df['title'].dropna().unique().tolist()
    return sorted(titles)
# Page-level settings: title, icon, centered layout, sidebar expanded.
st.set_page_config(
    layout="centered",
    initial_sidebar_state="expanded",
    page_title="PatentMatch",
    page_icon="🧊",
)

# Sidebar appearance overrides, injected as raw CSS.
# NOTE(review): the .css-* selectors look like auto-generated Streamlit class
# names and may stop matching after a Streamlit upgrade - confirm.
sidebar_style = """
<style>
.css-vk3wp9 {
background-color: rgb(255 255 255);
}
.css-18l0hbk {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
}
.css-nziaof {
padding: 0.34rem 1.2rem !important;
margin: 0.125rem 2rem;
background-color: rgb(181 197 227 / 18%) !important;
}
</style>
"""
st.markdown(sidebar_style, unsafe_allow_html=True)

# Hide the main menu, footer and header elements of the page.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(hide_st_style, unsafe_allow_html=True)
def app():
    """Render the main page: intro text, the three inputs and the score.

    Shows a title and project description, then a row with a context
    selector (filled from ``get_context()``) and two text inputs for the
    anchor and target phrases. Pressing "Predict Scores" calls ``infer`` with
    the current inputs and displays the resulting score with three decimals.
    """
    st.title("PatentMatch: Patent Semantic Similarity Matcher")
    #st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)")
    st.markdown(
        """This project is focused on developing a Transformer based NLP model to match phrases
in U.S. patents based on their semantic similarity within a specific
technical domain context. The trained model achieved Pearson correlation coefficient score of 0.745.
[[Source Code]](https://github.com/dataraptor/PatentMatch)
"""
    )
    st.markdown('---')

    # st.selectbox("Select from example",
    #     [
    #         "Example 1",
    #         "Example 2",
    #     ])

    row1_col1, row1_col2, row1_col3 = st.columns(
        [0.5, 0.4, 0.4]
    )

    # with row1_col1:
    #     frequency = st.selectbox("Section",
    #         [
    #             "A: Human Necessities",
    #             "B: Operations and Transport",
    #             "C: Chemistry and Metallurgy",
    #             "D: Textiles",
    #             "E: Fixed Constructions",
    #             "F: Mechanical Engineering",
    #             "G: Physics",
    #             "H: Electricity",
    #             "Y: Emerging Cross-Sectional Technologies",
    #         ])
    # with row1_col2:
    #     class_box = st.selectbox("Class",
    #         [
    #             "21",
    #             "14",
    #             "23",
    #         ])

    with row1_col1:
        contexts = get_context()
        # Preselect the demo context when the data contains it; otherwise
        # fall back to the first option instead of raising ValueError.
        default_context = 'basic electric elements'
        default_index = (
            contexts.index(default_context) if default_context in contexts else 0
        )
        context = st.selectbox("Context", contexts, index=default_index)
    with row1_col2:
        anchor = st.text_input("Anchor", "deflect light")
    with row1_col3:
        target = st.text_input("Target", "bending moment")

    if st.button("Predict Scores", type="primary"):
        with st.spinner("Predicting scores..."):
            score = infer(anchor, target, context)
            st.success("Scores predicted successfully!")
        # NOTE(review): the raw model output is shifted by a constant 2.0
        # before display; the reason is not visible in this file - confirm
        # against the label scaling used in training.
        score += 2.0
        st.subheader(f"Similarity Score: {score:<.3f}")
# Render the page body first so the footer is emitted after it.
app()

# Footer: a separator followed by author and data-source credits.
st.markdown("---")
st.markdown(
    "Built by [Shamim Ahamed](https://www.shamimahamed.com/). "
    "Data provided by [Kaggle](https://www.kaggle.com/competitions/us-patent-phrase-to-phrase-matching)"
)
# Unused alternative attribution; its URL points at a different Kaggle
# competition (feedbackprize-ellipse-corpus-scoring-challenge).
#st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")