Spaces:
Running
Running
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Load the image processor and the four classifiers once per server process.
# st.cache_resource shares a single copy across all browser sessions; keeping
# the objects only in st.session_state would reload all five checkpoints for
# every new session and hold a separate copy per connected user.
@st.cache_resource
def _load_models():
    """Load the shared image processor and the four fine-tuned classifiers.

    Returns:
        dict mapping the st.session_state attribute names used by the
        classification helpers to the loaded processor/model objects.
    """
    # NOTE(review): one processor (from the top-wear checkpoint) is used for
    # all four models; this assumes they share the same preprocessing config
    # — confirm against each repo's preprocessor_config.json.
    return {
        "image_processor": AutoImageProcessor.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
        ),
        "top_wear_model": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
        ),
        "pattern_model": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb"
        ),
        "print_model": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print"
        ),
        "sleeve_length_model": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length"
        ),
    }


# The classification helpers look the models up on st.session_state, so expose
# the cached (process-wide) objects there once per session.
if 'models_loaded' not in st.session_state:
    for _state_key, _loaded in _load_models().items():
        st.session_state[_state_key] = _loaded
    st.session_state.models_loaded = True
# Define image processing and classification functions
def topwear(encoding):
    """Return the top-wear label predicted for a preprocessed image.

    Reads the model from st.session_state, so call it from the Streamlit
    script thread (custom threads have no ScriptRunContext).

    Args:
        encoding: Output of the image processor (``return_tensors="pt"``),
            unpacked as keyword arguments into the model.

    Returns:
        The label from ``top_wear_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.top_wear_model
    # Inference only: no_grad avoids recording an autograd graph (model
    # parameters require grad by default), saving memory and time.
    with torch.no_grad():
        logits = model(**encoding).logits
    # argmax over the last (class) axis; .item() requires a single image.
    return model.config.id2label[logits.argmax(-1).item()]
def patterns(encoding):
    """Return the pattern label predicted for a preprocessed image.

    Reads the model from st.session_state, so call it from the Streamlit
    script thread (custom threads have no ScriptRunContext).

    Args:
        encoding: Output of the image processor (``return_tensors="pt"``),
            unpacked as keyword arguments into the model.

    Returns:
        The label from ``pattern_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.pattern_model
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits
    # argmax over the last (class) axis; .item() requires a single image.
    return model.config.id2label[logits.argmax(-1).item()]
def prints(encoding):
    """Return the print label predicted for a preprocessed image.

    Reads the model from st.session_state, so call it from the Streamlit
    script thread (custom threads have no ScriptRunContext).

    Args:
        encoding: Output of the image processor (``return_tensors="pt"``),
            unpacked as keyword arguments into the model.

    Returns:
        The label from ``print_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.print_model
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits
    # argmax over the last (class) axis; .item() requires a single image.
    return model.config.id2label[logits.argmax(-1).item()]
def sleevelengths(encoding):
    """Return the sleeve-length label predicted for a preprocessed image.

    Reads the model from st.session_state, so call it from the Streamlit
    script thread (custom threads have no ScriptRunContext).

    Args:
        encoding: Output of the image processor (``return_tensors="pt"``),
            unpacked as keyword arguments into the model.

    Returns:
        The label from ``sleeve_length_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.sleeve_length_model
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits
    # argmax over the last (class) axis; .item() requires a single image.
    return model.config.id2label[logits.argmax(-1).item()]
def imageprocessing(image):
    """Preprocess a PIL image into PyTorch tensors for the classifiers.

    The image is converted to RGB first. The returned encoding comes from the
    processor stored on st.session_state, called with ``return_tensors="pt"``,
    and is meant to be unpacked into a model as ``model(**encoding)``.
    """
    processor = st.session_state.image_processor
    rgb_image = image.convert("RGB")
    return processor(rgb_image, return_tensors="pt")
# Define the function that will be used in each thread
def call_model(func, encoding, results, index):
    """Thread target: evaluate ``func(encoding)`` into ``results[index]``.

    Returns nothing; the caller reads the shared ``results`` container after
    joining the thread. If ``func`` raises, the slot is left untouched.
    """
    prediction = func(encoding)
    results[index] = prediction
# Run all models in parallel
def pipes(image):
    """Classify a clothing image with all four models and collect the labels.

    The image is preprocessed once and the same encoding is fed to every
    classifier; the four forward passes run concurrently in a thread pool.

    Args:
        image: A PIL image (converted to RGB during preprocessing).

    Returns:
        dict with keys "top", "pattern", "print" and "sleeve_length", each
        mapped to the predicted label from that model's ``config.id2label``.

    Raises:
        Any exception from preprocessing or from a model's forward pass is
        re-raised here (via ``Future.result()``) instead of being lost in a
        worker thread.
    """
    # Everything that touches st.session_state happens here, on the script
    # thread. Threads we start ourselves have no ScriptRunContext, so session
    # state is not reachable from them; looking a model up inside a bare
    # threading.Thread failed there and left its result silently as None.
    encoding = imageprocessing(image)
    state = st.session_state
    models = {
        "top": state.top_wear_model,
        "pattern": state.pattern_model,
        "print": state.print_model,
        "sleeve_length": state.sleeve_length_model,
    }

    def _predict(model):
        # Grad mode is thread-local in PyTorch, so no_grad must be entered in
        # the worker thread that runs the forward pass.
        with torch.no_grad():
            logits = model(**encoding).logits
        # argmax over the last (class) axis; .item() requires a single image.
        return model.config.id2label[logits.argmax(-1).item()]

    with ThreadPoolExecutor(max_workers=len(models)) as pool:
        futures = {key: pool.submit(_predict, model) for key, model in models.items()}
        # .result() waits for each prediction and re-raises worker errors here.
        return {key: future.result() for key, future in futures.items()}
# Streamlit app UI
# Seconds requests may wait to connect to the image host and between received
# bytes (not a cap on total download time).
_IMAGE_FETCH_TIMEOUT = 15

st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")
if url:
    try:
        # NOTE(review): the URL is user-supplied and fetched server-side with
        # no host allow-list or download size cap (SSRF / memory risk on a
        # public deployment) — consider restricting both.
        # The timeout keeps an unresponsive host from hanging this session.
        response = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            # NOTE(review): use_column_width may be deprecated in newer
            # Streamlit releases; kept as-is — confirm against installed version.
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            # perf_counter is monotonic, so wall-clock adjustments cannot skew
            # the measured duration.
            start_time = time.perf_counter()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: surface download, timeout, decode or inference
        # failures to the user instead of crashing the page.
        st.error(f"Error processing the image: {str(e)}")