# lfashionnlp / fast_api.py
# Last commit by DINGOLANI: 139f8c4 "Rename api.py to fast_api.py" (verified)
import logging

import streamlit as st
import uvicorn
from fastapi import FastAPI, HTTPException
from transformers import pipeline
# Hugging Face model id of the fine-tuned token-classification model.
model_name = "DINGOLANI/distilbert-ner-v2"


@st.cache_resource
def _load_ner_pipeline(name: str):
    """Build the token-classification pipeline for model *name*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one pipeline per process so the model is not
    reloaded on each "Analyze" click.
    """
    return pipeline("token-classification", model=name, tokenizer=name)


# Load trained model once at import time; fail fast if it cannot be loaded.
try:
    nlp_ner = _load_ner_pipeline(model_name)
except Exception as e:
    # Chain the cause so the underlying loader traceback is preserved.
    raise RuntimeError(f"Failed to load model: {e}") from e
# BIO tags in the order of the model's raw output labels LABEL_1..LABEL_7.
# Any raw label not listed (e.g. LABEL_0) has no BIO tag here.
_BIO_TAGS = (
    "B-BRAND",
    "I-BRAND",
    "B-CATEGORY",
    "I-CATEGORY",
    "B-GENDER",
    "B-PRICE",
    "I-PRICE",
)

# Raw pipeline label -> BIO tag (corrected to match the training labels).
label_map = {f"LABEL_{i}": tag for i, tag in enumerate(_BIO_TAGS, start=1)}

# BIO tag -> entity type: the tag with its "B-"/"I-" prefix removed.
entity_filter = {tag: tag.partition("-")[2] for tag in label_map.values()}
# FastAPI application object; the route decorators below register on it.
app = FastAPI()


@app.get("/")
def home():
    """Liveness check: report that the NER service is up."""
    status = {"message": "NER API is running!"}
    return status
def _group_entities(tokens, tag_of_label, type_of_tag):
    """Group per-token NER predictions into ``{entity_type: [text, ...]}``.

    Each token's raw ``"entity"`` label is mapped through *tag_of_label*
    (raw label -> BIO tag) and then *type_of_tag* (BIO tag -> entity type).
    Tokens that map to no entity type are skipped and end any run in progress.

    WordPiece continuation pieces (``"##xx"``) are appended to the previous
    item when the immediately preceding token had the same entity type;
    otherwise the ``##`` is stripped and the piece starts a new item.
    B-/I- prefixes are not distinguished, so a multi-word entity comes back as
    separate items.
    """
    grouped = {}
    prev_entity = None
    for token in tokens:
        entity = type_of_tag.get(tag_of_label.get(token.get("entity")))
        if not entity:
            # Reset so a later '##' piece is never glued across a skipped token.
            prev_entity = None
            continue
        word = token["word"]
        if word.startswith("##"):
            if prev_entity == entity:
                grouped[entity][-1] += word[2:]
            else:
                grouped.setdefault(entity, []).append(word[2:])
        else:
            grouped.setdefault(entity, []).append(word)
        prev_entity = entity
    return grouped


@app.post("/predict/")
def predict(query: str):
    """Run NER on *query* and group the recognised tokens by entity type.

    FastAPI reads a bare ``str`` parameter as a URL query parameter, so the
    text arrives as ``?query=...``.

    Returns a dict with:
      * ``query`` - the input text, echoed back;
      * ``raw_output`` - the pipeline's per-token predictions, with ``score``
        coerced to a plain ``float``;
      * ``structured_output`` - ``{entity_type: [text, ...]}``.

    Raises:
        HTTPException: status 500 if inference or post-processing fails.
    """
    try:
        result = nlp_ner(query)
        for token in result:
            # Scores may be numpy floats, which JSON encoding does not accept.
            token["score"] = float(token["score"])
        logging.getLogger(__name__).debug("RAW MODEL OUTPUT: %s", result)
        structured_output = _group_entities(result, label_map, entity_filter)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing request: {e}"
        ) from e
    return {
        "query": query,
        "raw_output": result,
        "structured_output": structured_output,
    }
# 🚀 Streamlit Frontend
def main():
    """Render the Streamlit UI: read a query, run ``predict``, show results.

    ``predict`` is called in-process rather than over HTTP. If it raises
    ``HTTPException``, the error detail is shown on the page instead of an
    unhandled traceback.
    """
    st.set_page_config(page_title="Luxury Fashion NER", layout="wide")
    st.title("👜 Luxury Fashion Entity Extractor")
    st.write("Enter a text query and extract structured entities like **Brand, Category, Gender, and Price.**")
    query = st.text_input("Enter Query:", "Gucci handbags for women under $5000")
    if not st.button("Analyze"):
        return
    if not query.strip():
        st.warning("Please enter a query.")
        return
    try:
        response = predict(query)
    except HTTPException as exc:
        st.error(exc.detail)
        return
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("🔍 Structured Output")
        structured = response["structured_output"]
        if not structured:
            st.info("No entities found.")
        for key, value in structured.items():
            st.write(f"**{key}:** {', '.join(value)}")
    with col2:
        st.subheader("🛠 Raw Model Output")
        st.json(response["raw_output"])
if __name__ == "__main__":
    # Draw the Streamlit UI only when run as a script; importing this module
    # (e.g. to serve `app`) defines the API without rendering anything.
    main()