from fastapi import FastAPI, HTTPException
from transformers import pipeline
import uvicorn
import streamlit as st

# Load trained model
model_name = "DINGOLANI/distilbert-ner-v2"
try:
    nlp_ner = pipeline("token-classification", model=model_name, tokenizer=model_name)
except Exception as e:
    # Chain the original error so the real cause stays visible in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Corrected label mapping based on expected training labels.
# Labels absent from this table (e.g. "LABEL_0", presumably the "outside"
# tag -- confirm against the model config) are treated as non-entities.
label_map = {
    "LABEL_1": "B-BRAND",
    "LABEL_2": "I-BRAND",
    "LABEL_3": "B-CATEGORY",
    "LABEL_4": "I-CATEGORY",
    "LABEL_5": "B-GENDER",
    "LABEL_6": "B-PRICE",
    "LABEL_7": "I-PRICE",
}

# Collapses the begin/inside tags above into a single entity type.
entity_filter = {
    "B-BRAND": "BRAND",
    "I-BRAND": "BRAND",
    "B-CATEGORY": "CATEGORY",
    "I-CATEGORY": "CATEGORY",
    "B-GENDER": "GENDER",
    "B-PRICE": "PRICE",
    "I-PRICE": "PRICE",
}

app = FastAPI()


@app.get("/")
def home():
    """Return a static message confirming the API is up."""
    return {"message": "NER API is running!"}


def _group_entities(tokens: list) -> dict:
    """Group token-level predictions into ``{entity_type: [words, ...]}``.

    Each item of *tokens* is a dict read through its ``"entity"`` and
    ``"word"`` keys. A word starting with ``"##"`` is treated as a
    continuation piece: it is appended to the previous word when the
    immediately preceding token had the same entity type, otherwise it is
    stored on its own with the ``"##"`` prefix stripped.
    """
    grouped = {}
    prev_entity = None
    for token in tokens:
        entity = entity_filter.get(label_map.get(token.get("entity")))
        if not entity:
            # Forget the previous entity so a later "##" piece is not glued
            # onto an unrelated, non-adjacent word of the same type.
            prev_entity = None
            continue

        word = token["word"]
        if word.startswith("##"):
            piece = word[2:]
            if prev_entity == entity:
                grouped[entity][-1] += piece
            else:
                grouped.setdefault(entity, []).append(piece)
        else:
            grouped.setdefault(entity, []).append(word)
        prev_entity = entity
    return grouped


@app.post("/predict/")
def predict(query: str):
    """Run the NER pipeline on *query* and return raw and grouped results.

    Returns:
        A dict with the keys ``"query"`` (the input text), ``"raw_output"``
        (the pipeline's per-token predictions) and ``"structured_output"``
        (words grouped by entity type).

    Raises:
        HTTPException: with status 500 if anything fails while processing.
    """
    try:
        result = nlp_ner(query)
        for label in result:
            # Coerce to a built-in float; presumably the pipeline yields
            # NumPy scalars that cannot be serialised as-is -- confirm.
            label["score"] = float(label["score"])
        print("RAW MODEL OUTPUT:", result)
        return {
            "query": query,
            "raw_output": result,
            "structured_output": _group_entities(result),
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing request: {e}"
        ) from e


# 🚀 Streamlit Frontend
def main() -> None:
    """Render the Streamlit page: a query box and the extracted entities."""
    st.set_page_config(page_title="Luxury Fashion NER", layout="wide")
    st.title("👜 Luxury Fashion Entity Extractor")
    st.write("Enter a text query and extract structured entities like **Brand, Category, Gender, and Price.**")
    query = st.text_input("Enter Query:", "Gucci handbags for women under $5000")

    if st.button("Analyze"):
        try:
            response = predict(query)
        except HTTPException as exc:
            # predict() reports failures as HTTPException; show the message
            # on the page rather than an unhandled traceback.
            st.error(exc.detail)
            return

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("🔍 Structured Output")
            for key, value in response["structured_output"].items():
                st.write(f"**{key}:** {', '.join(value)}")
        with col2:
            st.subheader("🛠 Raw Model Output")
            st.json(response["raw_output"])


if __name__ == "__main__":
    main()