vishalkatheriya18 commited on
Commit
c6a1409
1 Parent(s): 9274ae9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -70
app.py CHANGED
@@ -1,94 +1,57 @@
1
  import streamlit as st
2
- from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  from PIL import Image
4
  import requests
5
  from io import BytesIO
6
  import time
7
- import torch
8
 
9
- # Load models and processor only once using Streamlit session state
10
  if 'models_loaded' not in st.session_state:
11
- st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
12
- st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
13
- st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
14
- st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
15
- st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
16
  st.session_state.models_loaded = True
17
 
18
- # Define image processing and classification functions
19
- def topwear(encoding):
20
- with torch.no_grad():
21
- outputs = st.session_state.top_wear_model(**encoding)
22
- logits = outputs.logits
23
- predicted_class_idx = logits.argmax(-1).item()
24
- st.write(f"Top Wear: {st.session_state.top_wear_model.config.id2label[predicted_class_idx]}")
25
- return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
26
-
27
- def patterns(encoding):
28
- with torch.no_grad():
29
- outputs = st.session_state.pattern_model(**encoding)
30
- logits = outputs.logits
31
- predicted_class_idx = logits.argmax(-1).item()
32
- st.write(f"Pattern: {st.session_state.pattern_model.config.id2label[predicted_class_idx]}")
33
- return st.session_state.pattern_model.config.id2label[predicted_class_idx]
34
-
35
- def prints(encoding):
36
- with torch.no_grad():
37
- outputs = st.session_state.print_model(**encoding)
38
- logits = outputs.logits
39
- predicted_class_idx = logits.argmax(-1).item()
40
- st.write(f"Print: {st.session_state.print_model.config.id2label[predicted_class_idx]}")
41
- return st.session_state.print_model.config.id2label[predicted_class_idx]
42
-
43
- def sleevelengths(encoding):
44
- with torch.no_grad():
45
- outputs = st.session_state.sleeve_length_model(**encoding)
46
- logits = outputs.logits
47
- predicted_class_idx = logits.argmax(-1).item()
48
- st.write(f"Sleeve Length: {st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]}")
49
- return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
50
-
51
- def imageprocessing(image):
52
- encoding = st.session_state.image_processor(images=image, return_tensors="pt")
53
- return encoding
54
-
55
- # Run all models sequentially
56
- def pipes(image):
57
- # Process the image once and reuse the encoding
58
- encoding = imageprocessing(image)
59
-
60
- # Get results from each model
61
- topwear_result = topwear(encoding)
62
- pattern_result = patterns(encoding)
63
- print_result = prints(encoding)
64
- sleeve_length_result = sleevelengths(encoding)
65
-
66
- # Combine the results into a dictionary
67
- results = {
68
- "top": topwear_result,
69
- "pattern": pattern_result,
70
- "print": print_result,
71
- "sleeve_length": sleeve_length_result
72
- }
73
- st.write(results)
74
- return results
75
 
76
  # Streamlit app UI
77
- st.title("Clothing Classification Pipeline")
78
 
79
  url = st.text_input("Paste image URL here...")
80
  if url:
81
  try:
82
  response = requests.get(url)
83
  if response.status_code == 200:
84
- image = Image.open(BytesIO(response.content))
85
  st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
86
 
87
  start_time = time.time()
88
 
89
- result = pipes(image)
90
- st.write("Classification Results (JSON):")
91
- st.json(result) # Display results in JSON format
 
 
 
 
 
 
 
 
92
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
93
  else:
94
  st.error("Failed to load image from URL. Please check the URL.")
 
1
  import streamlit as st
2
+ from ultralytics import YOLO
3
  from PIL import Image
4
  import requests
5
  from io import BytesIO
6
  import time
 
7
 
8
+ # Load the YOLO model only once using Streamlit session state
9
  if 'models_loaded' not in st.session_state:
10
+ st.session_state.yolo_model = YOLO('/kaggle/working/classification_project/yolo_classification/weights/best.pt') # Update with your model path
 
 
 
 
11
  st.session_state.models_loaded = True
12
 
13
+ # Define function for inference using YOLO
14
+ def predict_with_yolo(image):
15
+ # Run inference on the image using the YOLO model
16
+ results = st.session_state.yolo_model(image)
17
+
18
+ # Extract predictions
19
+ predictions = []
20
+ if results:
21
+ for result in results:
22
+ for box in result.boxes:
23
+ class_name = result.names[box.label]
24
+ confidence = box.conf.item() # Convert tensor to a Python float
25
+ predictions.append({
26
+ "Class": class_name,
27
+ "Confidence": confidence
28
+ })
29
+ return predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Streamlit app UI
32
+ st.title("Clothing Detection with YOLO")
33
 
34
  url = st.text_input("Paste image URL here...")
35
  if url:
36
  try:
37
  response = requests.get(url)
38
  if response.status_code == 200:
39
+ image = Image.open(BytesIO(response.content)).convert('RGB')
40
  st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
41
 
42
  start_time = time.time()
43
 
44
+ # Predict using YOLO
45
+ predictions = predict_with_yolo(image)
46
+
47
+ # Display predictions
48
+ if predictions:
49
+ st.write("Predictions:")
50
+ for pred in predictions:
51
+ st.write(f"Class: {pred['Class']}, Confidence: {pred['Confidence']:.2f}")
52
+ else:
53
+ st.write("No objects detected.")
54
+
55
  st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
56
  else:
57
  st.error("Failed to load image from URL. Please check the URL.")