File size: 4,201 Bytes
1f9a45b
 
 
 
 
 
 
 
5f0ae39
1f9a45b
 
 
 
 
 
 
 
5f0ae39
2c867d4
1f9a45b
0bb757b
 
1f9a45b
 
2c867d4
1f9a45b
0bb757b
 
1f9a45b
 
2c867d4
1f9a45b
0bb757b
 
1f9a45b
 
2c867d4
1f9a45b
0bb757b
 
1f9a45b
 
2c867d4
 
0bb757b
2c867d4
0bb757b
2c867d4
 
 
5f0ae39
2c867d4
0bb757b
2c867d4
1f9a45b
0bb757b
1f9a45b
6027458
0bb757b
1f9a45b
2c867d4
 
 
 
1f9a45b
6027458
0bb757b
1f9a45b
 
0bb757b
 
1f9a45b
 
 
0bb757b
 
 
5f0ae39
 
 
 
 
0bb757b
1f9a45b
 
 
 
1408ebf
 
2c867d4
 
 
 
 
 
 
 
 
1408ebf
 
 
2c867d4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Load models and processor only once using Streamlit session state.
# Streamlit re-executes this whole script on every UI interaction, so the
# heavy from_pretrained() downloads are guarded by a session-state flag and
# happen only on the first run of a session.
if 'models_loaded' not in st.session_state:
    # Shared preprocessor that turns PIL images into model-ready tensors.
    st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Four ConvNeXtV2-tiny classifiers, one per clothing attribute.
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    # Flag checked above so this branch is skipped on subsequent re-runs.
    st.session_state.models_loaded = True

# Define image processing and classification functions
def topwear(encoding):
    """Classify the topwear category of a preprocessed image.

    Args:
        encoding: tensor batch produced by the shared image processor.

    Returns:
        str: the predicted label from the topwear model's id2label map.
    """
    # Inference only — disable autograd to avoid building a gradient
    # graph, saving memory and time on every prediction.
    with torch.no_grad():
        outputs = st.session_state.top_wear_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.top_wear_model.config.id2label[predicted_class_idx]

def patterns(encoding):
    """Classify the fabric pattern of a preprocessed image.

    Args:
        encoding: tensor batch produced by the shared image processor.

    Returns:
        str: the predicted label from the pattern model's id2label map.
    """
    # Inference only — disable autograd to avoid building a gradient
    # graph, saving memory and time on every prediction.
    with torch.no_grad():
        outputs = st.session_state.pattern_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.pattern_model.config.id2label[predicted_class_idx]

def prints(encoding):
    """Classify the print style of a preprocessed image.

    Args:
        encoding: tensor batch produced by the shared image processor.

    Returns:
        str: the predicted label from the print model's id2label map.
    """
    # Inference only — disable autograd to avoid building a gradient
    # graph, saving memory and time on every prediction.
    with torch.no_grad():
        outputs = st.session_state.print_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.print_model.config.id2label[predicted_class_idx]

def sleevelengths(encoding):
    """Classify the sleeve length of a preprocessed image.

    Args:
        encoding: tensor batch produced by the shared image processor.

    Returns:
        str: the predicted label from the sleeve-length model's id2label map.
    """
    # Inference only — disable autograd to avoid building a gradient
    # graph, saving memory and time on every prediction.
    with torch.no_grad():
        outputs = st.session_state.sleeve_length_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]

def imageprocessing(image):
    """Run *image* through the shared processor, yielding model-ready tensors.

    Args:
        image: a PIL image (any mode; converted to RGB first).

    Returns:
        The processor's batch of PyTorch tensors (return_tensors="pt").
    """
    rgb_image = image.convert("RGB")
    return st.session_state.image_processor(rgb_image, return_tensors="pt")

# Define the function that will be used in each thread
def call_model(func, encoding, results, index):
    """Thread worker: apply *func* to *encoding*, storing the output.

    Args:
        func: classifier function taking the encoding and returning a label.
        encoding: preprocessed input passed straight through to *func*.
        results: shared list the caller pre-sized; written at *index*.
        index: slot in *results* reserved for this worker's output.
    """
    prediction = func(encoding)
    results[index] = prediction

# Run all models in parallel
def pipes(image):
    """Run the four attribute classifiers on *image* concurrently.

    The image is preprocessed once and the shared encoding is fed to each
    model in its own worker thread (blocking native inference releases the
    GIL, so the models genuinely overlap).

    Args:
        image: PIL image to classify.

    Returns:
        dict: predicted labels keyed by "top", "pattern", "print" and
        "sleeve_length".

    Raises:
        Whatever a worker raises: unlike bare threading.Thread workers —
        which would swallow the error and silently leave None in the
        results — ThreadPoolExecutor re-raises it here via future.result().
    """
    # Process the image once and reuse the encoding across all models.
    encoding = imageprocessing(image)

    # Output key -> classifier function, in the order callers expect.
    tasks = (
        ("top", topwear),
        ("pattern", patterns),
        ("print", prints),
        ("sleeve_length", sleevelengths),
    )

    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {key: executor.submit(fn, encoding) for key, fn in tasks}
        # .result() blocks until done and propagates worker exceptions.
        return {key: future.result() for key, future in futures.items()}

# Streamlit app UI: fetch an image by URL, classify it, show timing + JSON.
st.title("Clothing Classification Pipeline")

url = st.text_input("Paste image URL here...")
if url:
    try:
        # Bounded timeout so an unresponsive host cannot hang the app
        # indefinitely (requests has no default timeout).
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            # Small fixed-size preview; classification uses the full image.
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            start_time = time.time()

            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: surface any failure (bad URL, decode error,
        # model error) to the user instead of crashing the script run.
        st.error(f"Error processing the image: {str(e)}")