vishalkatheriya18 commited on
Commit
0bb757b
1 Parent(s): 2a803fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -26
app.py CHANGED
@@ -16,60 +16,81 @@ if 'models_loaded' not in st.session_state:
16
  st.session_state.models_loaded = True
17
 
18
  # Define image processing and classification functions
19
- def imageprocessing(image):
20
- encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
21
- st.write(encoding)
22
- return encoding
23
-
24
- def topwear(encoding):
25
  outputs = st.session_state.top_wear_model(**encoding)
26
- predicted_class_idx = outputs.logits.argmax(-1).item()
 
 
27
  return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
28
 
29
- def patterns(encoding):
 
30
  outputs = st.session_state.pattern_model(**encoding)
31
- predicted_class_idx = outputs.logits.argmax(-1).item()
 
 
32
  return st.session_state.pattern_model.config.id2label[predicted_class_idx]
33
 
34
- def prints(encoding):
 
35
  outputs = st.session_state.print_model(**encoding)
36
- predicted_class_idx = outputs.logits.argmax(-1).item()
 
 
37
  return st.session_state.print_model.config.id2label[predicted_class_idx]
38
 
39
- def sleevelengths(encoding):
 
40
  outputs = st.session_state.sleeve_length_model(**encoding)
41
- predicted_class_idx = outputs.logits.argmax(-1).item()
 
 
42
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
43
 
 
 
 
 
 
 
 
 
 
 
44
  # Run all models in parallel
45
- def pipes(image):
46
- encoding = imageprocessing(image)
 
47
 
 
48
  results = [None] * 4
49
 
50
- def update_results(index, func):
51
- results[index] = func(encoding)
52
-
53
  threads = [
54
- threading.Thread(target=update_results, args=(0, topwear)),
55
- threading.Thread(target=update_results, args=(1, patterns)),
56
- threading.Thread(target=update_results, args=(2, prints)),
57
- threading.Thread(target=update_results, args=(3, sleevelengths)),
58
  ]
59
 
 
60
  for thread in threads:
61
  thread.start()
 
 
62
  for thread in threads:
63
  thread.join()
64
 
65
- result_dict = {
66
- "topwear": results[0],
 
67
  "pattern": results[1],
68
  "print": results[2],
69
  "sleeve_length": results[3]
70
  }
71
 
72
- return result_dict
73
 
74
  # Streamlit app UI
75
  st.title("Clothing Classification Pipeline")
@@ -79,12 +100,13 @@ if url:
79
  response = requests.get(url)
80
  if response.status_code == 200:
81
  image = Image.open(BytesIO(response.content))
 
82
  st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
83
 
84
  start_time = time.time()
85
 
86
  try:
87
- result = pipes(image)
88
  st.write("Classification Results (JSON):")
89
  st.json(result) # Display results in JSON format
90
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
 
16
  st.session_state.models_loaded = True
17
 
18
  # Define image processing and classification functions
19
+ def topwear(encoding,top_wear_model):
20
+ # Make prediction
 
 
 
 
21
  outputs = st.session_state.top_wear_model(**encoding)
22
+ logits = outputs.logits
23
+ predicted_class_idx = logits.argmax(-1).item()
24
+ # Print the result
25
  return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
26
 
27
+ def patterns(encoding,pattern_model):
28
+ # Make prediction
29
  outputs = st.session_state.pattern_model(**encoding)
30
+ logits = outputs.logits
31
+ predicted_class_idx = logits.argmax(-1).item()
32
+ # Print the result
33
  return st.session_state.pattern_model.config.id2label[predicted_class_idx]
34
 
35
+ def prints(encoding,print_model):
36
+ # Make prediction
37
  outputs = st.session_state.print_model(**encoding)
38
+ logits = outputs.logits
39
+ predicted_class_idx = logits.argmax(-1).item()
40
+ # Print the result
41
  return st.session_state.print_model.config.id2label[predicted_class_idx]
42
 
43
+ def sleevelengths(encoding,sleeve_length_model):
44
+ # Make prediction
45
  outputs = st.session_state.sleeve_length_model(**encoding)
46
+ logits = outputs.logits
47
+ predicted_class_idx = logits.argmax(-1).item()
48
+ # Print the result
49
  return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
50
 
51
+ def imageprocessing(url):
52
+ response = requests.get(url)
53
+ if response.status_code == 200:
54
+ image = Image.open(BytesIO(response.content))
55
+ encoding = image_processor(image.convert("RGB"), return_tensors="pt")
56
+ return encoding
57
+
58
+ # Define the function that will be used in each thread
59
+ def call_model(func, encoding, model, results, index):
60
+ results[index] = func(encoding, model)
61
  # Run all models in parallel
62
+ def pipes(imagepath):
63
+ # Process the image once and reuse the encoding
64
+ encoding = imageprocessing(imagepath)
65
 
66
+ # Prepare a list to store the results from each thread
67
  results = [None] * 4
68
 
69
+ # Create threads for each function call
 
 
70
  threads = [
71
+ threading.Thread(target=call_model, args=(topwear, encoding, top_wear_model, results, 0)),
72
+ threading.Thread(target=call_model, args=(patterns, encoding, pattern_model, results, 1)),
73
+ threading.Thread(target=call_model, args=(prints, encoding, print_model, results, 2)),
74
+ threading.Thread(target=call_model, args=(sleevelengths, encoding, sleeve_length_model, results, 3)),
75
  ]
76
 
77
+ # Start all threads
78
  for thread in threads:
79
  thread.start()
80
+
81
+ # Wait for all threads to finish
82
  for thread in threads:
83
  thread.join()
84
 
85
+ # Combine the results into a dictionary
86
+ dicts = {
87
+ "top": results[0],
88
  "pattern": results[1],
89
  "print": results[2],
90
  "sleeve_length": results[3]
91
  }
92
 
93
+ return dicts
94
 
95
  # Streamlit app UI
96
  st.title("Clothing Classification Pipeline")
 
100
  response = requests.get(url)
101
  if response.status_code == 200:
102
  image = Image.open(BytesIO(response.content))
103
+ encoding = image_processor(image.convert("RGB"), return_tensors="pt")
104
  st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
105
 
106
  start_time = time.time()
107
 
108
  try:
109
+ result = pipes(url)
110
  st.write("Classification Results (JSON):")
111
  st.json(result) # Display results in JSON format
112
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")