Spaces:
Sleeping
Sleeping
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Load the processor and models once per server process (shared by every
# session via st.cache_resource), then expose them on st.session_state under
# the attribute names the classifier functions read.
_MODEL_REPO_PREFIX = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-"


@st.cache_resource
def _load_models():
    """Load the shared image processor and the four attribute classifiers.

    Decorated with ``st.cache_resource`` so the downloads and model
    construction happen once per process rather than once per browser session.

    Returns:
        dict mapping ``st.session_state`` attribute names to loaded objects.
    """
    return {
        # The top-wear checkpoint's processor is used for all four models.
        "image_processor": AutoImageProcessor.from_pretrained(_MODEL_REPO_PREFIX + "topwear"),
        "top_wear_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "topwear"),
        "pattern_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "pattern-rgb"),
        "print_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "print"),
        "sleeve_length_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "sleeve-length"),
    }


if 'models_loaded' not in st.session_state:
    for _state_key, _loaded in _load_models().items():
        st.session_state[_state_key] = _loaded
    st.session_state.models_loaded = True
# Define image processing and classification functions
def imageprocessing(image):
    """Convert *image* to RGB and run it through the shared image processor.

    Args:
        image: a PIL image in any mode.

    Returns:
        The processor output built with ``return_tensors="pt"``, suitable for
        unpacking as keyword arguments into a model call.
    """
    processor = st.session_state.image_processor
    rgb_image = image.convert("RGB")
    return processor(rgb_image, return_tensors="pt")
def topwear(encoding):
    """Predict the top-wear category for a preprocessed image.

    Args:
        encoding: output of ``imageprocessing``; unpacked as keyword arguments
            into the model. Presumably a batch of one image — ``.item()``
            below requires that.

    Returns:
        The ``id2label`` entry of the model config for the arg-max logit.
    """
    model = st.session_state.top_wear_model
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    # No st.write here: this runs on a pipes() worker thread, where Streamlit
    # output is unreliable, and the caller already displays every label.
    return model.config.id2label[predicted_class_idx]
def patterns(encoding):
    """Predict the pattern label for a preprocessed image.

    Args:
        encoding: output of ``imageprocessing``; unpacked as keyword arguments
            into the model. Presumably a batch of one image — ``.item()``
            below requires that.

    Returns:
        The ``id2label`` entry of the model config for the arg-max logit.
    """
    model = st.session_state.pattern_model
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def prints(encoding):
    """Predict the print label for a preprocessed image.

    Args:
        encoding: output of ``imageprocessing``; unpacked as keyword arguments
            into the model. Presumably a batch of one image — ``.item()``
            below requires that.

    Returns:
        The ``id2label`` entry of the model config for the arg-max logit.
    """
    model = st.session_state.print_model
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def sleevelengths(encoding):
    """Predict the sleeve-length label for a preprocessed image.

    Args:
        encoding: output of ``imageprocessing``; unpacked as keyword arguments
            into the model. Presumably a batch of one image — ``.item()``
            below requires that.

    Returns:
        The ``id2label`` entry of the model config for the arg-max logit.
    """
    model = st.session_state.sleeve_length_model
    # Pure inference: don't record an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Run all models in parallel
def pipes(image):
    """Preprocess *image* once, then run the four attribute classifiers concurrently.

    Args:
        image: a PIL image (converted to RGB by ``imageprocessing``).

    Returns:
        dict with keys ``"topwear"``, ``"pattern"``, ``"print"`` and
        ``"sleeve_length"`` mapping to each classifier's predicted label.

    Raises:
        Any exception raised by a classifier. It is re-raised here, in the
        calling thread, so the caller's error handling sees it instead of a
        silent ``None`` result.
    """
    try:
        # Described in Streamlit's multithreading guide but not a stable public
        # API, so the import is guarded and the context hand-off is optional.
        from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
    except ImportError:
        add_script_run_ctx = get_script_run_ctx = None

    encoding = imageprocessing(image)
    classifiers = {
        "topwear": topwear,
        "pattern": patterns,
        "print": prints,
        "sleeve_length": sleevelengths,
    }

    # Threads we spawn have no ScriptRunContext; per Streamlit's docs, st.*
    # calls (including st.session_state, which every classifier reads) need it.
    ctx = get_script_run_ctx() if get_script_run_ctx is not None else None

    def _attach_ctx():
        """Give the current worker thread this script run's context, if any."""
        if ctx is not None:
            add_script_run_ctx(threading.current_thread(), ctx)

    with ThreadPoolExecutor(max_workers=len(classifiers), initializer=_attach_ctx) as pool:
        futures = {key: pool.submit(func, encoding) for key, func in classifiers.items()}
        # Future.result() blocks until done and re-raises the worker's exception.
        return {key: future.result() for key, future in futures.items()}
# Streamlit app UI
st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")
if url:
    # NOTE(review): this fetches an arbitrary user-supplied URL from the server
    # (SSRF exposure on shared hosting); consider restricting schemes/hosts.
    try:
        # Without a timeout, a slow or silent host would block this run forever.
        response = requests.get(url, timeout=15)
    except requests.RequestException:
        # Malformed URLs, DNS/connection failures and timeouts all land here
        # and are reported like a bad HTTP status below.
        response = None

    if response is not None and response.status_code == 200:
        try:
            # Decoding is inside the try so non-image content is reported, not a crash.
            image = Image.open(BytesIO(response.content))
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            start_time = time.perf_counter()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        except Exception as e:  # top-level UI boundary: show the error to the user
            st.error(f"Error processing the image: {str(e)}")
    else:
        st.error("Failed to load image from URL. Please check the URL.")