"""Streamlit demo for the DEJAN Google Taxonomy classifier.

Classifies free-form text into one of 21 top-level categories and shows the
per-category probabilities as a horizontal bar chart.
"""

import os

import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DebertaV2ForSequenceClassification, DebertaV2Tokenizer

# Hugging Face access token, read from the environment (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
model_path = "dejanseo/DEJAN-Taxonomy-Classifier"


@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and classifier once and share them across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would reload the model on every rerun.
    ``st.cache_resource`` keeps one shared instance per server process.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # `token=` replaces the deprecated `use_auth_token=` argument.
    loaded_tokenizer = DebertaV2Tokenizer.from_pretrained(model_path, token=HF_TOKEN)
    loaded_model = DebertaV2ForSequenceClassification.from_pretrained(
        model_path, token=HF_TOKEN
    )
    return loaded_tokenizer, loaded_model


# Kept as module-level names so existing code referring to them still works.
tokenizer, model = load_model_and_tokenizer()

# LABEL_MAPPING (from model index to numeric ID) and corresponding category names
LABEL_MAPPING = {
    1: 0,
    8: 1,
    111: 2,
    141: 3,
    166: 4,
    222: 5,
    412: 6,
    436: 7,
    469: 8,
    536: 9,
    537: 10,
    632: 11,
    772: 12,
    783: 13,
    888: 14,
    922: 15,
    988: 16,
    1239: 17,
    2092: 18,
    5181: 19,
    5605: 20,
}

CATEGORY_NAMES = {
    1: "Animals & Pet Supplies",
    8: "Arts & Entertainment",
    111: "Business & Industrial",
    141: "Cameras & Optics",
    166: "Apparel & Accessories",
    222: "Electronics",
    412: "Food, Beverages & Tobacco",
    436: "Furniture",
    469: "Health & Beauty",
    536: "Home & Garden",
    537: "Baby & Toddler",
    632: "Hardware",
    772: "Mature",
    783: "Media",
    888: "Vehicles & Parts",
    922: "Office Supplies",
    988: "Sporting Goods",
    1239: "Toys & Games",
    2092: "Software",
    5181: "Luggage & Bags",
    5605: "Religious & Ceremonial",
}

# Reverse mapping: model output index -> display label "[category_id] name".
INDEX_TO_CATEGORY = {
    index: f"[{category_id}] {CATEGORY_NAMES[category_id]}"
    for category_id, index in LABEL_MAPPING.items()
}

# Set Streamlit app title
st.title("Google Taxonomy Classifier by DEJAN")
st.write("Enter text in the input box, and the model will classify it into one of the 21 top level categories. This demo showcases early model capability while the full 5000+ label model is undergoing extensive training.")
st.write("Works for product descriptions, search queries, articles, social media posts and broadly web text of any style. Suitable for classification pipelines of millions of queries.")

# Input text box
input_text = st.text_area("Enter text for classification:")


def classify_text(text):
    """Return the model's per-category probabilities for *text*.

    Args:
        text: Input string. Inputs longer than 512 tokens are truncated.

    Returns:
        A list of floats (softmax probabilities) indexed by model output
        index, or None if *text* is empty or whitespace-only.
        ``INDEX_TO_CATEGORY`` maps each index to its display label.
    """
    if not text.strip():
        return None

    # Tokenize and encode input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single input gives a batch of one row. Take row 0 rather than
    # squeeze(), so the result is always a list, never a bare float.
    return F.softmax(logits, dim=-1)[0].tolist()


def _plot_probabilities(sorted_categories):
    """Build a horizontal bar chart from (label, probability) pairs, highest first."""
    categories = [label for label, _ in sorted_categories]
    values = [prob for _, prob in sorted_categories]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(categories, values)
    ax.set_xlabel("Probability")
    ax.set_ylabel("Category")
    ax.set_title("Classification Probabilities")
    ax.invert_yaxis()  # Ensure highest probability is at the top
    ax.set_xlim(0, 1)  # Set the x-axis range to 0-1 for probabilities
    return fig


# Display results when text is entered
if st.button("Classify"):
    if input_text.strip():
        # A spinner disappears when inference finishes. Plain st.write text
        # would stay on the page.
        with st.spinner("Processing..."):
            probabilities = classify_text(input_text)

        if probabilities:
            # Map probabilities to categories
            mapped_probs = {INDEX_TO_CATEGORY[idx]: prob for idx, prob in enumerate(probabilities)}

            # Sort categories by probability in descending order
            sorted_categories = sorted(mapped_probs.items(), key=lambda item: item[1], reverse=True)

            fig = _plot_probabilities(sorted_categories)
            st.pyplot(fig)
            # pyplot keeps a reference to every figure it creates. Close this
            # one so repeated clicks don't pile up figures in the server.
            plt.close(fig)

            # Additional information at the end
            st.divider()
            # NOTE(review): "bulk link prediction" looks copied from another
            # app. Confirm the intended wording for this text classifier.
            st.markdown("""
Interested in using this in an automated pipeline for bulk link prediction?
Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.
""")
        else:
            st.error("Could not classify the text. Please try again.")
    else:
        st.warning("Please enter some text for classification.")