# top_wear / app.py
# Author: vishalkatheriya18
# Last commit: 1408ebf (verified) - "Update app.py"
# File size: 3.82 kB
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
import threading
import time

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Load the image processor and the four classifiers once, then expose them via
# st.session_state for the classification functions below.
_MODEL_REPO_PREFIX = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-"


def _load_models():
    """Load the image processor and the four image-classification models.

    Returns:
        dict mapping the ``st.session_state`` attribute names read by the
        classification functions to the loaded processor/model objects.
    """
    return {
        "image_processor": AutoImageProcessor.from_pretrained(_MODEL_REPO_PREFIX + "topwear"),
        "top_wear_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "topwear"),
        "pattern_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "pattern-rgb"),
        "print_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "print"),
        "sleeve_length_model": AutoModelForImageClassification.from_pretrained(_MODEL_REPO_PREFIX + "sleeve-length"),
    }


# st.session_state is per browser session, so loading straight into it keeps a
# separate copy of every model for each visitor. st.cache_resource (newer
# Streamlit releases) keeps one process-wide copy that all sessions share; on
# versions without it, fall back to the previous per-session loading.
if hasattr(st, "cache_resource"):
    _load_models = st.cache_resource(_load_models)

if 'models_loaded' not in st.session_state:
    for _name, _obj in _load_models().items():
        st.session_state[_name] = _obj
    st.session_state.models_loaded = True
# Define image processing and classification functions
def imageprocessing(image):
    """Convert *image* to RGB and run the shared image processor on it.

    Args:
        image: A PIL image; it is converted to RGB before processing.

    Returns:
        The processor output requested with ``return_tensors="pt"`` (PyTorch
        tensors), ready to be unpacked into the classification models.
    """
    processor = st.session_state.image_processor
    rgb_image = image.convert("RGB")
    return processor(rgb_image, return_tensors="pt")
def topwear(encoding):
    """Predict the top-wear category label for a preprocessed image.

    Args:
        encoding: Output of ``imageprocessing``; it is unpacked as keyword
            arguments into the model call.

    Returns:
        The label from ``top_wear_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.top_wear_model
    # Inference only: skip autograd bookkeeping. PyTorch grad mode is
    # thread-local, so it is disabled here, inside whichever thread runs this.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    # No st.write here: this is run on a worker thread by ``pipes`` (Streamlit
    # commands are not supported there), and the label is already returned
    # to the caller, matching the other classifier functions.
    return model.config.id2label[predicted_class_idx]
def patterns(encoding):
    """Predict the pattern label for a preprocessed image.

    Args:
        encoding: Output of ``imageprocessing``; it is unpacked as keyword
            arguments into the model call.

    Returns:
        The label from ``pattern_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.pattern_model
    # Inference only; grad mode is thread-local, so disable it in this thread.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def prints(encoding):
    """Predict the print label for a preprocessed image.

    Args:
        encoding: Output of ``imageprocessing``; it is unpacked as keyword
            arguments into the model call.

    Returns:
        The label from ``print_model.config.id2label`` for the highest-scoring
        class.
    """
    model = st.session_state.print_model
    # Inference only; grad mode is thread-local, so disable it in this thread.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def sleevelengths(encoding):
    """Predict the sleeve-length label for a preprocessed image.

    Args:
        encoding: Output of ``imageprocessing``; it is unpacked as keyword
            arguments into the model call.

    Returns:
        The label from ``sleeve_length_model.config.id2label`` for the
        highest-scoring class.
    """
    model = st.session_state.sleeve_length_model
    # Inference only; grad mode is thread-local, so disable it in this thread.
    with torch.no_grad():
        outputs = model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Run all models in parallel
def _script_ctx_initializer():
    """Return a worker-thread initializer that attaches the caller's Streamlit
    ScriptRunContext, or None when no context can be obtained.

    Streamlit resolves ``st.session_state`` through the current thread's script
    context; a freshly started worker thread has none, so the classifier
    functions could not read their models from it.
    """
    try:
        from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
    except ImportError:  # Streamlit builds without this module path
        return None
    ctx = get_script_run_ctx()
    if ctx is None:  # e.g. not running under `streamlit run`
        return None
    return lambda: add_script_run_ctx(threading.current_thread(), ctx)


def pipes(image):
    """Preprocess *image* once, then run all four classifiers concurrently.

    Args:
        image: A PIL image to classify.

    Returns:
        dict with keys "topwear", "pattern", "print" and "sleeve_length",
        each mapped to the label predicted by the corresponding model.

    Raises:
        Any exception raised during preprocessing or by a classifier is
        re-raised here, so the caller's error handling sees it.
    """
    encoding = imageprocessing(image)
    classifiers = {
        "topwear": topwear,
        "pattern": patterns,
        "print": prints,
        "sleeve_length": sleevelengths,
    }
    with ThreadPoolExecutor(
        max_workers=len(classifiers),
        initializer=_script_ctx_initializer(),
    ) as pool:
        futures = {key: pool.submit(func, encoding) for key, func in classifiers.items()}
        # future.result() re-raises a worker's exception in this thread; with
        # bare threads the error was only printed and the slot stayed None.
        return {key: future.result() for key, future in futures.items()}
# Streamlit app UI
st.title("Clothing Classification Pipeline")
# Pasted URLs often carry surrounding whitespace; strip it before fetching.
url = st.text_input("Paste image URL here...").strip()
if url:
    # The URL is user-supplied and fetched server-side. The timeout stops an
    # unresponsive host from hanging this script run. Network failures
    # (malformed URL, DNS error, refused connection, timeout) are reported in
    # the UI instead of surfacing as an uncaught traceback.
    # NOTE(review): visitors can make the server request arbitrary URLs;
    # consider restricting allowed hosts if the app is publicly exposed.
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        response = None
    if response is not None and response.status_code == 200:
        try:
            # Decoding happens inside the try so a non-image payload is shown
            # as an error message rather than crashing the script run.
            image = Image.open(BytesIO(response.content))
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            start_time = time.perf_counter()  # monotonic clock for elapsed time
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        except Exception as e:
            st.error(f"Error processing the image: {str(e)}")
    else:
        st.error("Failed to load image from URL. Please check the URL.")