# Provenance (from the scraped file-page header): author kritsadaK,
# commit f231914 "Initial commit", file size 6.01 kB.
import warnings
import torchvision
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
# Keep the app output clean: turn off torchvision's beta-transforms notice,
# then ignore every UserWarning whose originating module name matches "torchvision".
torchvision.disable_beta_transforms_warning()
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
@st.cache_resource
def _load_fill_mask_components():
    """Load the tokenizer, masked-LM model and fill-mask pipeline.

    Streamlit re-executes this whole script on every widget interaction, so the
    result is cached as a shared resource instead of re-loading the weights on
    each rerun.

    Tries the Thai WangchanBERTa checkpoint first and falls back to
    ``xlm-roberta-base`` if loading it raises.

    Returns:
        A ``(tokenizer, model, pipe)`` tuple.
    """
    try:
        tok = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
        mdl = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
    except Exception as err:  # any load failure triggers the fallback model
        # Surface the underlying error instead of silently discarding it.
        st.warning(f"Switching to xlm-roberta-base model due to compatibility issues. ({err})")
        tok = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
        mdl = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
    # Initialize the fill-mask pipeline on the same model/tokenizer pair.
    fill_mask = pipeline("fill-mask", model=mdl, tokenizer=tok, framework="pt")
    return tok, mdl, fill_mask


# Module-level names used by the rest of the script.
tokenizer, model, pipe = _load_fill_mask_components()
def get_embedding(text):
    """Return a sentence embedding for *text*.

    The embedding is the mean of the model's last hidden layer over the
    non-padding tokens. The masked-LM head's vocabulary logits are
    token-prediction scores, not a representation of the sentence, so they
    are not used here.

    Args:
        text: The sentence to embed.

    Returns:
        A NumPy array of shape ``(1, hidden_size)``, suitable for
        ``sklearn.metrics.pairwise.cosine_similarity``.
    """
    # truncation=True caps the input at the tokenizer's maximum length so long
    # user input cannot exceed the model's position embeddings.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones(last_hidden.shape[:2], dtype=torch.long)
    weights = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * weights).sum(dim=1)
    # clamp avoids division by zero for a degenerate all-padding input.
    counts = weights.sum(dim=1).clamp(min=1.0)
    return (summed / counts).cpu().numpy()
# Streamlit app setup
# Static page header: the app title, a subtitle, and a markdown walkthrough of the
# workflow implemented further down this script. The example predictions and scores
# in the walkthrough are hard-coded illustrative text, not computed output.
# NOTE(review): nothing visible in this script loads a Thai Law dataset; the models
# loaded above are general-purpose checkpoints — confirm the subtitle is accurate.
st.title("Thai Full Sentence Similarity App")
st.write("""
## using Thai Law nlp dataset""")
st.write("""
### How This App Works
This app uses a mask-filling model to predict possible words or phrases that could fill in the `<mask>` token in a given sentence. It then calculates the similarity of each prediction with the original sentence to determine the most contextually appropriate completion.
### Example Sentence
In this example, we have the following sentence in Thai with a `<mask>` token:
- **Input**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ"`
- **Translation**: "Many tourists choose to visit `<mask>` to experience nature."
The `<mask>` token represents a location popular for its natural beauty.
### Potential Predictions
Here are some possible predictions the model might generate for `<mask>`:
1. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"` - Chiang Mai
2. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เขาใหญ่ เพื่อสัมผัสธรรมชาติ"` - Khao Yai
3. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เกาะสมุย เพื่อสัมผัสธรรมชาติ"` - Koh Samui
4. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน ภูเก็ต เพื่อสัมผัสธรรมชาติ"` - Phuket
### Results Table
For each prediction, the app calculates:
- **Similarity Score**: Indicates how similar the predicted sentence is to the original input.
- **Model Score**: Represents the model's confidence in the predicted word for `<mask>`.
### Most Similar Prediction
The app will display the most contextually similar prediction based on the similarity score. For example:
- **Most Similar Prediction**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"`
- **Similarity Score**: 0.89
- **Model Score**: 0.16
Feel free to enter your own sentence with `<mask>` and explore the predictions!
""")
# User input box
st.subheader("Input Text")
input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "เมนูโปรดของฉันคือ <mask> ที่ทำจากวัตถุดิบสดใหม่")
# Ensure the input includes a `<mask>`. Blank input is normalised to "" so the
# `if input_text:` guard that follows skips processing. Otherwise the model
# would run on a bare " <mask>" and the user would see a misleading warning.
if not input_text.strip():
    input_text = ""
    st.info("Enter a sentence containing `<mask>` to see predictions.")
elif "<mask>" not in input_text:
    input_text += " <mask>"
    st.warning("`<mask>` token was missing in your input. It has been added automatically.")
# Process the input when available
if input_text:
    st.write(f"Input Text: {input_text}")
    # The result handling below expects one flat list of candidate dicts from the
    # fill-mask pipeline. That only holds for a single mask, so reject several masks.
    if input_text.count("<mask>") > 1:
        st.error("Please use exactly one `<mask>` token; multiple masks are not supported.")
    else:
        # Generate baseline embedding (removing `<mask>` to get the full sentence)
        baseline_text = input_text.replace("<mask>", "")
        input_embedding = get_embedding(baseline_text)
        # Generate mask predictions and calculate similarity with the full sentences
        similarity_results = []
        try:
            result = pipe(input_text)
            for r in result:
                prediction_text = r.get('sequence', '')
                if prediction_text:
                    prediction_embedding = get_embedding(prediction_text)
                    similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
                    similarity_results.append({
                        "Prediction": prediction_text,
                        "Similarity Score": similarity,
                        "Model Score": r['score'],
                    })
        except (KeyError, AttributeError, TypeError):
            # Pipeline output did not have the expected list-of-dicts shape.
            st.error("Unexpected model output structure; unable to retrieve predictions.")
        else:
            if not similarity_results:
                # An empty DataFrame has no "Similarity Score" column and no row 0,
                # so sorting/indexing below would raise.
                st.warning("The model returned no predictions for this input.")
            else:
                # Convert results to DataFrame for easy sorting and display
                df_results = pd.DataFrame(similarity_results).sort_values(by="Similarity Score", ascending=False)
                # Display all predictions sorted by similarity score
                st.subheader("All Predictions Sorted by Similarity")
                st.dataframe(df_results)
                # Display the most similar prediction (first row after the descending sort)
                most_similar = df_results.iloc[0]
                st.subheader("Most Similar Prediction")
                st.write(f"**Prediction**: {most_similar['Prediction']}")
                st.write(f"**Similarity Score**: {most_similar['Similarity Score']:.4f}")
                st.write(f"**Model Score**: {most_similar['Model Score']:.4f}")