root-sajjan committed
Commit ecaf31c · verified · 1 Parent(s): 9ab5c32
Files changed (3)
  1. app.py +47 -0
  2. model.py +172 -0
  3. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,47 @@
+ from fastapi import FastAPI, File, UploadFile
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pathlib import Path
+ import shutil
+
+ from model import YOLOModel
+
+ yolo = YOLOModel()
+
+ UPLOAD_FOLDER = Path("./uploads")
+ UPLOAD_FOLDER.mkdir(exist_ok=True)
+
+ app = FastAPI()
+
+ # Allow the front-end dev servers to call this API
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["http://192.168.56.1:3000", "http://192.168.56.1:3001"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.post("/upload")
+ async def upload_image(image: UploadFile = File(...)):
+     try:
+         # Save the upload to disk so YOLO can read it by path
+         file_path = UPLOAD_FOLDER / image.filename
+         with file_path.open("wb") as buffer:
+             shutil.copyfileobj(image.file, buffer)
+         # Perform YOLO inference
+         predictions = yolo.predict(str(file_path))
+         # Clean up the uploaded file after processing
+         file_path.unlink()
+         return JSONResponse(content={"items": predictions})
+     except Exception as e:
+         return JSONResponse(content={"error": str(e)}, status_code=500)
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
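
For a quick smoke test of the endpoint, a minimal client sketch (assumptions: the server is running locally on port 7860 as in the `__main__` block above, and a `sample.jpg` exists in the working directory):

    import requests

    # The multipart field name must be "image" to match the UploadFile parameter
    with open("sample.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/upload",
            files={"image": ("sample.jpg", f, "image/jpeg")},
        )
    print(resp.json())  # {"items": [...]} on success, {"error": "..."} with HTTP 500 on failure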
model.py ADDED
@@ -0,0 +1,172 @@
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image, ImageDraw
+ import pytesseract
+ import os
+
+ from llm import inference, upload_image
+
+ # Directory for the per-detection crops that get uploaded for LLM inference
+ cropped_images_dir = "cropped_images"
+ os.makedirs(cropped_images_dir, exist_ok=True)
+
+
+ class YOLOModel:
+     def __init__(self, model_path="yolov5s.pt"):
+         """
+         Initialize the YOLO model. Downloads the YOLOv5 weights if not available.
+         """
+         # Work around torch.hub's forked-repo validation check
+         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+         self.model = torch.hub.load("ultralytics/yolov5", "custom", path=model_path, force_reload=True)
+         # Candidate brands per category, kept for the (currently disabled) CLIP fallback:
+         # self.category_brands = {
+         #     "electronics": ["Samsung", "Apple", "Sony", "LG", "Panasonic"],
+         #     "furniture": ["Ikea", "Ashley", "La-Z-Boy", "Wayfair", "West Elm"],
+         #     "appliances": ["Whirlpool", "GE", "Samsung", "LG", "Bosch"],
+         #     "vehicles": ["Tesla", "Toyota", "Ford", "Honda", "Chevrolet"],
+         #     "chair": ["Ikea", "Ashley", "Wayfair", "La-Z-Boy", "Herman Miller"],
+         #     "microwave": ["Samsung", "Panasonic", "Sharp", "LG", "Whirlpool"],
+         #     "table": ["Ikea", "Wayfair", "Ashley", "CB2", "West Elm"],
+         #     "oven": ["Whirlpool", "GE", "Samsung", "Bosch", "LG"],
+         #     "potted plant": ["The Sill", "PlantVine", "Lowe's", "Home Depot", "UrbanStems"],
+         #     "couch": ["Ikea", "Ashley", "Wayfair", "La-Z-Boy", "CushionCo"],
+         #     "cow": ["Angus", "Hereford", "Jersey", "Holstein", "Charolais"],
+         #     "bed": ["Tempur-Pedic", "Ikea", "Sealy", "Serta", "Sleep Number"],
+         #     "tv": ["Samsung", "LG", "Sony", "Vizio", "TCL"],
+         #     "bin": ["Rubbermaid", "Sterilite", "Hefty", "Glad", "Simplehuman"],
+         #     "refrigerator": ["Whirlpool", "GE", "Samsung", "LG", "Bosch"],
+         #     "laptop": ["Dell", "HP", "Apple", "Lenovo", "Asus"],
+         #     "smartphone": ["Apple", "Samsung", "Google", "OnePlus", "Huawei"],
+         #     "camera": ["Canon", "Nikon", "Sony", "Fujifilm", "Panasonic"],
+         #     "toaster": ["Breville", "Cuisinart", "Black+Decker", "Hamilton Beach", "Oster"],
+         #     "fan": ["Dyson", "Honeywell", "Lasko", "Vornado", "Bionaire"],
+         #     "vacuum cleaner": ["Dyson", "Shark", "Roomba", "Hoover", "Bissell"]
+         # }
+
+     def predict_clip(self, image, brand_names):
+         """
+         Predict the most probable brand using CLIP.
+         The CLIP model and processor are loaded lazily on first call, since
+         the main pipeline currently relies on the LLM instead.
+         """
+         if not hasattr(self, "clip_model"):
+             self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+             self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         inputs = self.clip_processor(
+             text=brand_names,
+             images=image,
+             return_tensors="pt",
+             padding=True
+         )
+         outputs = self.clip_model(**inputs)
+         logits_per_image = outputs.logits_per_image
+         probs = logits_per_image.softmax(dim=1)  # convert logits to probabilities
+         best_idx = probs.argmax().item()
+         return brand_names[best_idx], probs[0, best_idx].item()
+
+     def predict_text(self, image):
+         """Extract any legible text from the crop with Tesseract OCR."""
+         grayscale = image.convert("L")
+         text = pytesseract.image_to_string(grayscale)
+         return text.strip()
+
+     def predict(self, image_path):
+         """
+         Run YOLO inference on an image.
+
+         :param image_path: Path to the input image
+         :return: List of predictions with labels and bounding boxes
+         """
+         results = self.model(image_path)
+         image = Image.open(image_path).convert("RGB")
+         draw = ImageDraw.Draw(image)
+         predictions = results.pandas().xyxy[0]  # predictions as a pandas DataFrame
+         print(f'YOLO predictions:\n\n{predictions}')
+         output = []
+         for idx, row in predictions.iterrows():
+             category = row["name"]
+             confidence = row["confidence"]
+             bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]
+
+             # Crop the detected region and save it to disk
+             cropped_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
+             cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
+             cropped_image.save(cropped_image_path, "JPEG")
+
+             # Upload the crop to the cloud to get a URL the LLM can fetch
+             print("Uploading crop to image host")
+             image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
+             print(f"Image URL received as {image_url}")
+             # Query the LLM for brand, model, price and description
+             result_llms = inference.get_name(image_url, category)
+
+             # CLIP fallback (currently disabled): match the category to its
+             # candidate brands and let CLIP pick the most likely one.
+             # if category in self.category_brands:
+             #     predicted_brand, clip_confidence = self.predict_clip(cropped_image, self.category_brands[category])
+             # else:
+             #     predicted_brand, clip_confidence = "Unknown", 0.0
+
+             detected_text = self.predict_text(cropped_image)
+             print(f"Detected text: {detected_text}")
+             print(f"Predicted brand: {result_llms['model']}")
+             # Draw the bounding box and brand label on the image
+             draw.rectangle(bbox, outline="red", width=3)
+             draw.text(
+                 (bbox[0], bbox[1] - 10),
+                 f"{result_llms['brand']}",
+                 fill="red"
+             )
+
+             # Append result
+             output.append({
+                 "category": category,
+                 "bbox": bbox,
+                 "confidence": confidence,
+                 "category_llm": result_llms["brand"],
+                 "predicted_brand": result_llms["model"],
+                 "price": result_llms["price"],
+                 "details": result_llms["description"],
+                 "detected_text": detected_text,
+             })
+
+         # Remove crops left over from earlier runs with more detections
+         valid_indices = set(range(len(predictions)))
+         for filename in os.listdir(cropped_images_dir):
+             if filename.startswith("crop_") and filename.endswith(".jpg"):
+                 try:
+                     file_idx = int(filename.split("_")[1].split(".")[0])
+                     if file_idx not in valid_indices:
+                         os.remove(os.path.join(cropped_images_dir, filename))
+                         print(f"Deleted stale crop: {filename}")
+                 except ValueError:
+                     # Skip files that don't match the crop_<idx>.jpg pattern
+                     continue
+
+         return output
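
model.py imports `inference` and `upload_image` from an `llm` package that is not part of this commit. A minimal stub sketch of the interface the code assumes; the function names come from the call sites above and the returned keys mirror how `result_llms` is indexed, so everything here is an assumption about the real module:

    # llm/upload_image.py (stub sketch; real implementation not in this commit)
    import requests

    IMGBB_API_KEY = "..."  # assumption: supplied via env/config

    def upload_image_to_imgbb(image_path):
        """Upload a local image to imgbb and return its public URL."""
        with open(image_path, "rb") as f:
            resp = requests.post(
                "https://api.imgbb.com/1/upload",
                params={"key": IMGBB_API_KEY},
                files={"image": f},
            )
        return resp.json()["data"]["url"]

    # llm/inference.py (stub sketch)
    def get_name(image_url, category):
        """Identify the object at image_url; keys mirror predict()'s usage."""
        return {"brand": "Unknown", "model": "Unknown", "price": "N/A", "description": ""}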
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi[all]
+ ultralytics==8.2.52
+ ultralytics-thop==2.0.0
+ torch==2.2.2
+ pytesseract==0.3.10
+ pillow==11.0.0
+ nltk==3.9.1
+ transformers @ git+https://github.com/huggingface/transformers@e6889961761eba27ac13a30aa12444a6f4406f78
+ huggingface-hub==0.24.6
+ # python 3.11.4