top_wear / app.py
vishalkatheriya18's picture
Update app.py
2c867d4 verified
raw
history blame
4.2 kB
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Load the processor and models once per server process and expose them via
# Streamlit session state, which the inference functions below read from.

# Hugging Face Hub repo id for the shared image processor. The top-wear repo's
# preprocessing config is reused for every model, since a single encoding is
# fed to all four classifiers.
_PROCESSOR_REPO = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"

# session_state key -> Hugging Face Hub repo id of each classifier.
_MODEL_REPOS = {
    "top_wear_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear",
    "pattern_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb",
    "print_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print",
    "sleeve_length_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length",
}


@st.cache_resource
def _load_models():
    """Load the image processor and all classifiers from the Hugging Face Hub.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process and shared by every browser session. ``st.session_state`` alone is
    per-session, which would load a fresh copy of every model for each visitor.

    Returns:
        A ``(processor, models)`` tuple where ``models`` maps each key of
        ``_MODEL_REPOS`` to its loaded model.
    """
    processor = AutoImageProcessor.from_pretrained(_PROCESSOR_REPO)
    models = {
        key: AutoModelForImageClassification.from_pretrained(repo)
        for key, repo in _MODEL_REPOS.items()
    }
    return processor, models


if 'models_loaded' not in st.session_state:
    _processor, _models = _load_models()
    st.session_state.image_processor = _processor
    for _key, _model in _models.items():
        st.session_state[_key] = _model
    st.session_state.models_loaded = True
# Define image processing and classification functions
def topwear(encoding):
    """Predict a label with the top-wear classifier.

    Args:
        encoding: Output of the shared image processor (see
            ``imageprocessing``); unpacked as keyword arguments into the model.

    Returns:
        The entry of the model's ``config.id2label`` for the highest-scoring
        class. Assumes a single-image batch, because ``.item()`` requires a
        one-element tensor.
    """
    model = st.session_state.top_wear_model
    # Inference only: skip autograd bookkeeping. Grad mode is thread-local, so
    # it must be entered here, inside the worker thread that runs the model.
    with torch.no_grad():
        logits = model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def patterns(encoding):
    """Predict a label with the pattern classifier.

    Args:
        encoding: Output of the shared image processor (see
            ``imageprocessing``); unpacked as keyword arguments into the model.

    Returns:
        The entry of the model's ``config.id2label`` for the highest-scoring
        class. Assumes a single-image batch, because ``.item()`` requires a
        one-element tensor.
    """
    model = st.session_state.pattern_model
    # Inference only: skip autograd bookkeeping. Grad mode is thread-local, so
    # it must be entered here, inside the worker thread that runs the model.
    with torch.no_grad():
        logits = model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def prints(encoding):
    """Predict a label with the print classifier.

    Args:
        encoding: Output of the shared image processor (see
            ``imageprocessing``); unpacked as keyword arguments into the model.

    Returns:
        The entry of the model's ``config.id2label`` for the highest-scoring
        class. Assumes a single-image batch, because ``.item()`` requires a
        one-element tensor.
    """
    model = st.session_state.print_model
    # Inference only: skip autograd bookkeeping. Grad mode is thread-local, so
    # it must be entered here, inside the worker thread that runs the model.
    with torch.no_grad():
        logits = model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def sleevelengths(encoding):
    """Predict a label with the sleeve-length classifier.

    Args:
        encoding: Output of the shared image processor (see
            ``imageprocessing``); unpacked as keyword arguments into the model.

    Returns:
        The entry of the model's ``config.id2label`` for the highest-scoring
        class. Assumes a single-image batch, because ``.item()`` requires a
        one-element tensor.
    """
    model = st.session_state.sleeve_length_model
    # Inference only: skip autograd bookkeeping. Grad mode is thread-local, so
    # it must be entered here, inside the worker thread that runs the model.
    with torch.no_grad():
        logits = model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def imageprocessing(image):
    """Encode an image into PyTorch tensors for the classifiers.

    The image (anything exposing ``.convert``, e.g. a PIL image) is converted
    to RGB and passed to the shared ``st.session_state.image_processor`` with
    ``return_tensors="pt"``. The returned encoding is what every classifier
    function unpacks into its model.
    """
    processor = st.session_state.image_processor
    rgb_image = image.convert("RGB")
    return processor(rgb_image, return_tensors="pt")
# Thread-target helper: run one classifier and store its output by position.
def call_model(func, encoding, results, index):
    """Evaluate ``func(encoding)`` and write the value into ``results[index]``.

    Returns nothing; the result is delivered through the mutable ``results``
    sequence so it can be used as a ``threading.Thread`` target. If ``func``
    raises, the exception propagates and ``results`` is left unchanged.
    """
    prediction = func(encoding)
    results[index] = prediction
# Run all models in parallel
def pipes(image):
    """Run every classifier on ``image`` concurrently and collect the labels.

    The image is encoded once with ``imageprocessing`` and the same encoding is
    handed to each classifier in its own worker thread.

    Args:
        image: Image accepted by ``imageprocessing`` (anything with
            ``.convert``).

    Returns:
        Dict with keys ``"top"``, ``"pattern"``, ``"print"`` and
        ``"sleeve_length"`` (in that order), mapping to the label each
        classifier returned.

    Raises:
        Whatever exception a classifier raised. ``Future.result()`` re-raises
        it, whereas a bare ``threading.Thread`` would swallow it and leave the
        slot as ``None``.
    """
    # Process the image once and reuse the encoding
    encoding = imageprocessing(image)
    tasks = {
        "top": topwear,
        "pattern": patterns,
        "print": prints,
        "sleeve_length": sleevelengths,
    }
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {key: executor.submit(func, encoding) for key, func in tasks.items()}
        # Iterating in insertion order keeps the original key order.
        dicts = {key: future.result() for key, future in futures.items()}
    return dicts
# Streamlit app UI

# Seconds to wait on the remote server when fetching the image. Without a
# timeout, ``requests.get`` can block this script run indefinitely.
IMAGE_FETCH_TIMEOUT = 10

st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")
if url:
    try:
        # NOTE(review): the URL is user-supplied and fetched server-side; if
        # this app runs next to internal services, consider restricting hosts.
        response = requests.get(url, timeout=IMAGE_FETCH_TIMEOUT)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            # perf_counter is monotonic, so the elapsed time is not skewed by
            # wall-clock adjustments.
            start_time = time.perf_counter()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: report any failure (network, timeout, image
        # decoding, model inference) to the user instead of crashing the page.
        st.error(f"Error processing the image: {str(e)}")