vishalkatheriya18 commited on
Commit
1f9a45b
·
verified ·
1 Parent(s): 9510ecb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import threading
7
+ import time
8
+
9
+ # Load models and processor only once using session state
10
+ if 'models_loaded' not in st.session_state:
11
+ # Image processor
12
+ st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
13
+ # Topwear model
14
+ st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
15
+ # Pattern model
16
+ st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
17
+ # Print model
18
+ st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
19
+ # Sleeve length model
20
+ st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
21
+ st.session_state.models_loaded = True
22
+
23
+ # Functions for predictions
24
+ def topwear(encoding):
25
+ outputs = st.session_state.top_wear_model(**encoding)
26
+ predicted_class_idx = outputs.logits.argmax(-1).item()
27
+ return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
28
+
29
+ def patterns(encoding):
30
+ outputs = st.session_state.pattern_model(**encoding)
31
+ predicted_class_idx = outputs.logits.argmax(-1).item()
32
+ return st.session_state.pattern_model.config.id2label[predicted_class_idx]
33
+
34
+ def prints(encoding):
35
+ outputs = st.session_state.print_model(**encoding)
36
+ predicted_class_idx = outputs.logits.argmax(-1).item()
37
+ return st.session_state.print_model.config.id2label[predicted_class_idx]
38
+
39
+ def sleevelengths(encoding):
40
+ outputs = st.session_state.sleeve_length_model(**encoding)
41
+ predicted_class_idx = outputs.logits.argmax(-1).item()
42
+ return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
43
+
44
+ def imageprocessing(url):
45
+ response = requests.get(url)
46
+ image = Image.open(BytesIO(response.content))
47
+ encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
48
+ return encoding, image
49
+
50
+ def pipes(imagepath):
51
+ encoding, image = imageprocessing(imagepath)
52
+ # Using threading for faster results
53
+ results = [None] * 4
54
+ threads = [
55
+ threading.Thread(target=lambda: results.__setitem__(0, topwear(encoding))),
56
+ threading.Thread(target=lambda: results.__setitem__(1, patterns(encoding))),
57
+ threading.Thread(target=lambda: results.__setitem__(2, prints(encoding))),
58
+ threading.Thread(target=lambda: results.__setitem__(3, sleevelengths(encoding))),
59
+ ]
60
+ for thread in threads:
61
+ thread.start()
62
+ for thread in threads:
63
+ thread.join()
64
+
65
+ dicts = {"top": results[0], "pattern": results[1], "print": results[2], "sleeve_length": results[3]}
66
+ return dicts, image
67
+
68
+ # Streamlit app UI
69
+ st.title("Clothing Classification Pipeline")
70
+
71
+ image_url = st.text_input("Enter Image URL")
72
+
73
+ if image_url:
74
+ start_time = time.time()
75
+ results, img = pipes(image_url)
76
+ st.image(img.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
77
+
78
+ # Display results
79
+ st.write("Classification Results:")
80
+ st.write(f"Topwear: {results['top']}")
81
+ st.write(f"Pattern: {results['pattern']}")
82
+ st.write(f"Print: {results['print']}")
83
+ st.write(f"Sleeve Length: {results['sleeve_length']}")
84
+
85
+ st.write(f"Time taken: {time.time() - start_time:.2f} seconds")