vishalkatheriya18 commited on
Commit
5f0ae39
1 Parent(s): 6027458

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -5,23 +5,23 @@ import requests
5
  from io import BytesIO
6
  import threading
7
  import time
8
- import json
9
 
10
- # Load models and processor only once using session state
11
  if 'models_loaded' not in st.session_state:
12
- # Image processor
13
  st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
14
- # Topwear model
15
  st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
16
- # Pattern model
17
  st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
18
- # Print model
19
  st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
20
- # Sleeve length model
21
  st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
22
  st.session_state.models_loaded = True
23
 
24
- # Functions for predictions
 
 
 
 
 
 
25
  def topwear(encoding):
26
  outputs = st.session_state.top_wear_model(**encoding)
27
  predicted_class_idx = outputs.logits.argmax(-1).item()
@@ -42,15 +42,10 @@ def sleevelengths(encoding):
42
  predicted_class_idx = outputs.logits.argmax(-1).item()
43
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
44
 
45
- def imageprocessing(url):
46
- response = requests.get(url)
47
- image = Image.open(BytesIO(response.content))
48
- encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
49
- return encoding, image
50
 
51
- def pipes(imagepath):
52
- encoding, image = imageprocessing(imagepath)
53
- # Using threading for faster results
54
  results = [None] * 4
55
 
56
  def update_results(index, func):
@@ -68,21 +63,27 @@ def pipes(imagepath):
68
  for thread in threads:
69
  thread.join()
70
 
71
- dicts = {"top": results[0], "pattern": results[1], "print": results[2], "sleeve_length": results[3]}
72
- return dicts, image
 
 
 
 
 
 
73
 
74
  # Streamlit app UI
75
  st.title("Clothing Classification Pipeline")
76
 
77
  image_url = st.text_input("Enter Image URL")
78
-
79
  if image_url:
80
  start_time = time.time()
81
- results, img = pipes(image_url)
82
- st.image(img.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
83
-
84
- # Display results as JSON
85
- st.write("Classification Results (JSON):")
86
- st.json(results) # Output as JSON
87
 
88
- st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
 
 
 
 
 
 
 
 
5
  from io import BytesIO
6
  import threading
7
  import time
 
8
 
9
+ # Load models and processor only once using Streamlit session state
10
  if 'models_loaded' not in st.session_state:
 
11
  st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
 
12
  st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
 
13
  st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
 
14
  st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
 
15
  st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
16
  st.session_state.models_loaded = True
17
 
18
+ # Define image processing and classification functions
19
+ def imageprocessing(url):
20
+ response = requests.get(url)
21
+ image = Image.open(BytesIO(response.content))
22
+ encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
23
+ return encoding, image
24
+
25
  def topwear(encoding):
26
  outputs = st.session_state.top_wear_model(**encoding)
27
  predicted_class_idx = outputs.logits.argmax(-1).item()
 
42
  predicted_class_idx = outputs.logits.argmax(-1).item()
43
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
44
 
45
+ # Run all models in parallel
46
+ def pipes(image_url):
47
+ encoding, image = imageprocessing(image_url)
 
 
48
 
 
 
 
49
  results = [None] * 4
50
 
51
  def update_results(index, func):
 
63
  for thread in threads:
64
  thread.join()
65
 
66
+ result_dict = {
67
+ "topwear": results[0],
68
+ "pattern": results[1],
69
+ "print": results[2],
70
+ "sleeve_length": results[3]
71
+ }
72
+
73
+ return result_dict, image
74
 
75
  # Streamlit app UI
76
  st.title("Clothing Classification Pipeline")
77
 
78
  image_url = st.text_input("Enter Image URL")
 
79
  if image_url:
80
  start_time = time.time()
 
 
 
 
 
 
81
 
82
+ try:
83
+ result, img = pipes(image_url)
84
+ st.image(img.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
85
+ st.write("Classification Results (JSON):")
86
+ st.json(result) # Display results in JSON format
87
+ st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
88
+ except Exception as e:
89
+ st.error(f"Error processing the image: {str(e)}")