# Spaces: Sleeping  (hosting-page status text captured along with this file; not code)
import logging

import streamlit as st
import uvicorn
from fastapi import FastAPI, HTTPException
from transformers import pipeline
# Load the fine-tuned token-classification model once, at import time, so that
# predict() can reuse the module-level `nlp_ner` pipeline for every query.
# NOTE(review): under `streamlit run` the whole script is re-executed on each
# widget interaction, so this load repeats per rerun; consider caching it
# (e.g. with st.cache_resource) if the Streamlit version in use supports it.
model_name = "DINGOLANI/distilbert-ner-v2"
try:
    nlp_ner = pipeline("token-classification", model=model_name, tokenizer=model_name)
except Exception as e:
    # Chain the original error so its traceback (download, auth, config, ...)
    # is preserved instead of being flattened into the message string.
    raise RuntimeError(f"Failed to load model: {e}") from e
# BIO tags in the order of the model's raw output ids: LABEL_1 .. LABEL_7.
# NOTE(review): "LABEL_0" has no entry, so lookups for it yield None;
# presumably it is the "O" (outside) tag — confirm against the model config.
_BIO_TAGS = (
    "B-BRAND",
    "I-BRAND",
    "B-CATEGORY",
    "I-CATEGORY",
    "B-GENDER",
    "B-PRICE",
    "I-PRICE",
)

# Raw pipeline label ("LABEL_<n>") -> BIO tag.
label_map = {f"LABEL_{index}": tag for index, tag in enumerate(_BIO_TAGS, start=1)}

# BIO tag -> coarse entity type (the tag with its "B-"/"I-" prefix removed).
entity_filter = {tag: tag.partition("-")[2] for tag in _BIO_TAGS}
def _create_app():
    """Construct the FastAPI application object with default settings."""
    return FastAPI()


# NOTE(review): no route decorators (e.g. @app.get) appear in this file, so as
# written this app exposes no endpoints — confirm whether they were lost.
app = _create_app()
def home():
    """Return a static status payload confirming the service is up.

    NOTE(review): no FastAPI route decorator is attached here in this file.
    """
    status_text = "NER API is running!"
    return dict(message=status_text)
def _group_entities(tokens):
    """Collapse per-token predictions into ``{entity_type: [text, ...]}``.

    Each token's raw ``"entity"`` label is translated through ``label_map`` to a
    BIO tag and then through ``entity_filter`` to a coarse entity type; tokens
    whose label is not in those maps are treated as outside any entity.

    A token extends the previous span of the same entity type when it is a
    WordPiece continuation (``##`` prefix, joined without a space) or carries an
    ``I-`` tag (joined with one space). Any other entity token starts a new span.
    """
    grouped = {}
    open_entity = None  # entity type of the span the next token may extend
    for token in tokens:
        bio_tag = label_map.get(token.get("entity"))
        entity = entity_filter.get(bio_tag)
        if entity is None:
            # A non-entity token closes the current span; otherwise a "##"
            # piece of some later word could be glued onto an earlier entity.
            open_entity = None
            continue
        word = token["word"]
        is_subword = word.startswith("##")
        piece = word[2:] if is_subword else word
        if open_entity == entity and (is_subword or bio_tag.startswith("I-")):
            separator = "" if is_subword else " "
            grouped[entity][-1] += separator + piece
        else:
            grouped.setdefault(entity, []).append(piece)
        open_entity = entity
    return grouped


def predict(query: str):
    """Run the NER pipeline on *query* and return raw and grouped predictions.

    NOTE(review): no FastAPI route decorator is attached here in this file, so
    this is only reached via main(); confirm whether a decorator was lost.

    Returns:
        A dict with ``"query"`` (the input, unchanged), ``"raw_output"`` (the
        pipeline's per-token predictions with each ``"score"`` coerced to a
        builtin ``float`` — presumably so it is JSON-serializable; confirm) and
        ``"structured_output"`` (``{entity_type: [text, ...]}``).

    Raises:
        HTTPException: status 500 wrapping any error raised while running the
        model or post-processing its output.
    """
    try:
        result = nlp_ner(query)
        for token in result:
            token["score"] = float(token["score"])
        logging.getLogger(__name__).debug("RAW MODEL OUTPUT: %s", result)
        return {
            "query": query,
            "raw_output": result,
            "structured_output": _group_entities(result),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {e}") from e
# Streamlit Frontend
def main():
    """Render the Streamlit UI: a query box, then structured and raw NER output.

    Failures from predict() (raised as HTTPException) are shown with st.error
    instead of escaping main().

    NOTE(review): the "π" characters in the UI strings look like mis-encoded
    emoji; they are kept byte-for-byte here — confirm the intended glyphs.
    """
    st.set_page_config(page_title="Luxury Fashion NER", layout="wide")
    st.title("π Luxury Fashion Entity Extractor")
    st.write("Enter a text query and extract structured entities like **Brand, Category, Gender, and Price.**")
    query = st.text_input("Enter Query:", "Gucci handbags for women under $5000")
    if not st.button("Analyze"):
        return

    try:
        response = predict(query)
    except HTTPException as exc:
        # predict() reports failures as HTTP errors for API use; in the UI,
        # show the message rather than letting the exception escape main().
        st.error(exc.detail)
        return

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("π Structured Output")
        structured = response["structured_output"]
        if not structured:
            st.write("No entities detected.")
        for key, value in structured.items():
            st.write(f"**{key}:** {', '.join(value)}")
    with col2:
        st.subheader("π Raw Model Output")
        st.json(response["raw_output"])


# Entry point: presumably `streamlit run <this file>` executes it as __main__.
if __name__ == "__main__":
    main()