Spaces:
Sleeping
Sleeping
import warnings | |
import torchvision | |
import torch | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline | |
from sklearn.metrics.pairwise import cosine_similarity | |
import streamlit as st | |
# Silence torchvision's beta-transforms notice and any other UserWarning
# raised from torchvision modules so they do not clutter the app's logs.
torchvision.disable_beta_transforms_warning()
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="torchvision",
)
# Checkpoints used by the app: the Thai-specific model is preferred; the
# multilingual one is only used when the preferred checkpoint fails to load.
PRIMARY_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
FALLBACK_MODEL_NAME = "xlm-roberta-base"


@st.cache_resource
def _load_fill_mask_components():
    """Load the tokenizer, masked-LM model and fill-mask pipeline once.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded objects are cached with ``st.cache_resource`` instead of being
    re-created on each rerun.

    Returns:
        tuple: ``(tokenizer, model, model_name, pipe, used_fallback)`` where
        ``used_fallback`` is True when the primary checkpoint could not be
        loaded and the fallback checkpoint was used instead.

    Raises:
        Exception: whatever ``from_pretrained`` / ``pipeline`` raise when the
        fallback checkpoint or the pipeline itself cannot be created.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL_NAME, use_fast=False)
        loaded_model = AutoModelForMaskedLM.from_pretrained(PRIMARY_MODEL_NAME)
        loaded_name = PRIMARY_MODEL_NAME
        used_fallback = False
    except Exception:
        # Deliberately broad: any failure to load the primary checkpoint
        # (missing tokenizer dependency, download error, ...) triggers the
        # fallback rather than crashing the app.
        loaded_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL_NAME)
        loaded_model = AutoModelForMaskedLM.from_pretrained(FALLBACK_MODEL_NAME)
        loaded_name = FALLBACK_MODEL_NAME
        used_fallback = True

    # Initialize the fill-mask pipeline on top of whichever model was loaded.
    fill_mask = pipeline("fill-mask", model=loaded_model, tokenizer=loaded_tokenizer, framework="pt")
    return loaded_tokenizer, loaded_model, loaded_name, fill_mask, used_fallback


# Module-level names used by the rest of the script.
tokenizer, model, model_name, pipe, _used_fallback_model = _load_fill_mask_components()

# Emitted outside the cached loader so the notice is shown on every rerun,
# not only on the run that populated the cache.
if _used_fallback_model:
    st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
# Function to generate embeddings
def get_embedding(text):
    """Return the model output at the first token position for ``text``.

    The text is tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model`` with gradient tracking disabled.

    Note that the vector is taken from the masked-LM ``logits`` (sliced as
    ``[:, 0, :]``), not from a hidden state.

    Args:
        text: The sentence to encode.

    Returns:
        A NumPy array holding ``logits[:, 0, :]`` moved to the CPU; the
        leading axis of the logits is kept.
    """
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded).logits
    first_position = logits[:, 0, :]
    return first_position.cpu().numpy()
# Streamlit app setup: page title followed by a static markdown description.
st.title("Thai Full Sentence Similarity App")

# Blank lines separate the markdown sections; in particular the closing
# sentence needs one so it is not rendered as part of the last bullet item.
st.write("""
### How This App Works

This app uses a mask-filling model to predict possible words or phrases that could fill in the `<mask>` token in a given sentence. It then calculates the similarity of each prediction with the original sentence to determine the most contextually appropriate completion.

### Example Sentence

In this example, we have the following sentence in Thai with a `<mask>` token:
- **Input**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ"`
- **Translation**: "Many tourists choose to visit `<mask>` to experience nature."

The `<mask>` token represents a location popular for its natural beauty.

### Potential Predictions

Here are some possible predictions the model might generate for `<mask>`:
1. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"` - Chiang Mai
2. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เขาใหญ่ เพื่อสัมผัสธรรมชาติ"` - Khao Yai
3. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เกาะสมุย เพื่อสัมผัสธรรมชาติ"` - Koh Samui
4. `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน ภูเก็ต เพื่อสัมผัสธรรมชาติ"` - Phuket

### Results Table

For each prediction, the app calculates:
- **Similarity Score**: Indicates how similar the predicted sentence is to the original input.
- **Model Score**: Represents the model's confidence in the predicted word for `<mask>`.

### Most Similar Prediction

The app will display the most contextually similar prediction based on the similarity score. For example:
- **Most Similar Prediction**: `"นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน เชียงใหม่ เพื่อสัมผัสธรรมชาติ"`
- **Similarity Score**: 0.89
- **Model Score**: 0.16

Feel free to enter your own sentence with `<mask>` and explore the predictions!
""")
# User input box (pre-filled with a Thai example sentence).
st.subheader("Input Text")
input_text = st.text_input(
    "Enter a sentence with `<mask>` to find similar predictions:",
    "ผู้ใช้งานท่าอากาศยานนานาชาติ <mask> มีกว่าสามล้านคน",
).strip()

if not input_text:
    # Keep input_text empty so that processing guarded by `if input_text:`
    # is skipped, instead of querying the model with a bare " <mask>".
    st.info("Please enter a sentence to get predictions.")
elif "<mask>" not in input_text:
    # Ensure the input includes a `<mask>` for the fill-mask pipeline.
    input_text += " <mask>"
    st.warning("`<mask>` token was missing in your input. It has been added automatically.")
# Process the input when available: predict fillers for `<mask>`, score each
# completed sentence against the mask-free baseline, and display the ranking.
if input_text:
    st.write(f"Input Text: {input_text}")

    # Generate baseline embedding (removing `<mask>` to get the full sentence)
    baseline_text = input_text.replace("<mask>", "")
    input_embedding = get_embedding(baseline_text)

    # Generate mask predictions and calculate similarity with the full sentences
    similarity_results = []
    try:
        result = pipe(input_text)

        # With more than one `<mask>` in the input the pipeline returns one
        # list of candidates per mask; flatten so every candidate is scored.
        # NOTE(review): such candidates still contain the other, unfilled
        # masks in their 'sequence' text.
        if result and isinstance(result[0], list):
            result = [candidate for per_mask in result for candidate in per_mask]

        for r in result:
            prediction_text = r.get("sequence", "")
            # Only proceed if we have a valid prediction text
            if not prediction_text:
                continue
            prediction_embedding = get_embedding(prediction_text)
            similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
            similarity_results.append({
                "Prediction": prediction_text,
                "Similarity Score": similarity,
                "Model Score": r["score"],
            })
    except KeyError:
        # A candidate without a 'score' entry means the output format is not
        # the one this app was written against.
        st.error("Unexpected model output structure; unable to retrieve predictions.")
    else:
        if similarity_results:
            # Convert results to DataFrame for easy sorting and display
            df_results = pd.DataFrame(similarity_results).sort_values(
                by="Similarity Score", ascending=False
            )

            # Display all predictions sorted by similarity score
            st.subheader("All Predictions Sorted by Similarity")
            st.dataframe(df_results)

            # Display the most similar prediction (first row after sorting)
            most_similar = df_results.iloc[0]
            st.subheader("Most Similar Prediction")
            st.write(f"**Prediction**: {most_similar['Prediction']}")
            st.write(f"**Similarity Score**: {most_similar['Similarity Score']:.4f}")
            st.write(f"**Model Score**: {most_similar['Model Score']:.4f}")
        else:
            # Without this guard an empty result list would fail inside
            # sort_values / iloc[0] instead of telling the user what happened.
            st.warning("The model returned no usable predictions for this input.")