|
import torch |
|
from pathlib import Path |
|
from transformers import CLIPProcessor, CLIPModel |
|
from PIL import Image, ImageDraw |
|
import pytesseract |
|
import requests |
|
import os |
|
from llm import inference, upload_image |
|
|
|
import re |
|
|
|
|
|
cropped_images_dir = "cropped_images" |
|
os.makedirs(cropped_images_dir, exist_ok=True) |
|
|
|
|
|
class YOLOModel: |
|
def __init__(self, model_path="yolov5s.pt"): |
|
""" |
|
Initialize the YOLO model. Downloads YOLOv5 pretrained model if not available. |
|
""" |
|
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True |
|
self.model = torch.hub.load("ultralytics/yolov5", "custom", path=model_path, force_reload=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_clip(self, image, brand_names): |
|
""" |
|
Predict the most probable brand using CLIP. |
|
""" |
|
inputs = self.clip_processor( |
|
text=brand_names, |
|
images=image, |
|
return_tensors="pt", |
|
padding=True |
|
) |
|
|
|
outputs = self.clip_model(**inputs) |
|
logits_per_image = outputs.logits_per_image |
|
probs = logits_per_image.softmax(dim=1) |
|
best_idx = probs.argmax().item() |
|
return brand_names[best_idx], probs[0, best_idx].item() |
|
|
|
|
|
def predict_text(self, image): |
|
try: |
|
|
|
grayscale = image.convert('L') |
|
|
|
|
|
text = pytesseract.image_to_string(grayscale) |
|
|
|
|
|
return text.strip() |
|
except Exception as e: |
|
|
|
print(f"Error during text prediction: {e}") |
|
|
|
|
|
return "" |
|
|
|
|
|
|
|
def predict(self, image_path): |
|
""" |
|
Run YOLO inference on an image. |
|
|
|
:param image_path: Path to the input image |
|
:return: List of predictions with labels and bounding boxes |
|
""" |
|
results = self.model(image_path) |
|
image = Image.open(image_path).convert("RGB") |
|
draw = ImageDraw.Draw(image) |
|
predictions = results.pandas().xyxy[0] |
|
print(f'YOLO predictions:\n\n{predictions}') |
|
output = [] |
|
for idx, row in predictions.iterrows(): |
|
category = row['name'] |
|
confidence = row['confidence'] |
|
bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]] |
|
|
|
|
|
cropped_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3])) |
|
cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg") |
|
cropped_image.save(cropped_image_path, "JPEG") |
|
|
|
|
|
print(f'Uploading now to image url') |
|
image_url = upload_image.upload_image_to_imgbb(cropped_image_path) |
|
print(f'Image URL received as{image_url}') |
|
|
|
result_llms = inference.get_name(image_url, category) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
''' |
|
# Match category to possible brands |
|
if category in self.category_brands: |
|
possible_brands = self.category_brands[category] |
|
print(f'Predicting with CLIP:\n\n') |
|
predicted_brand, clip_confidence = self.predict_clip(cropped_image, possible_brands) |
|
else: |
|
predicted_brand, clip_confidence = "Unknown", 0.0 |
|
''' |
|
|
|
|
|
detected_text = self.predict_text(cropped_image) |
|
print(f'Details:{detected_text}') |
|
print(f'Predicted brand: {result_llms["model"]}') |
|
|
|
draw.rectangle(bbox, outline="red", width=3) |
|
draw.text( |
|
(bbox[0], bbox[1] - 10), |
|
f'{result_llms["brand"]})', |
|
fill="red" |
|
) |
|
|
|
|
|
output.append({ |
|
"category": category, |
|
"bbox": bbox, |
|
"confidence": confidence, |
|
"category_llm":result_llms["brand"], |
|
"predicted_brand": result_llms["model"], |
|
|
|
"price":result_llms["price"], |
|
"details":result_llms["description"], |
|
"detected_text":detected_text, |
|
"image_url":image_url |
|
}) |
|
|
|
valid_indices = set(range(len(predictions))) |
|
|
|
|
|
for filename in os.listdir(cropped_images_dir): |
|
|
|
if filename.startswith("crop_") and filename.endswith(".jpg"): |
|
|
|
try: |
|
file_idx = int(filename.split("_")[1].split(".")[0]) |
|
if file_idx not in valid_indices: |
|
|
|
file_path = os.path.join(cropped_images_dir, filename) |
|
os.remove(file_path) |
|
print(f"Deleted excess file: {filename}") |
|
except ValueError: |
|
|
|
continue |
|
|
|
return output |
|
|
|
|
|
|