File size: 3,439 Bytes
3aba7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from langchain.tools import BaseTool
from PIL import Image, ImageDraw
import requests
from dotenv import load_dotenv
import os
# Populate os.environ via python-dotenv so that HUGGINGFACEHUB_API_TOKEN, which the
# query helpers in this module read from os.environ, can be supplied from a .env file.
load_dotenv()


def object_detection_query(filepath, *, timeout=60):
    """Run object detection on an image via the Hugging Face DETR endpoint.

    The raw bytes of the file are posted to the hosted
    ``facebook/detr-resnet-50`` inference API.

    Args:
        filepath: Path to the image file to analyse.
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON response. Callers in this module expect a list of
        dicts with ``label``, ``score`` and ``box`` keys.

    Raises:
        KeyError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set in the environment.
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
    headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
    with open(filepath, "rb") as f:
        data = f.read()
    response = requests.post(api_url, headers=headers, data=data, timeout=timeout)
    # Fail loudly on HTTP errors instead of handing an error payload to callers
    # that would then try to index it as a list of detections.
    response.raise_for_status()
    return response.json()

def bounding_box(filepath, *, color="red", line_width=2):
    """Return a copy of the image at *filepath* annotated with detected objects.

    Calls ``object_detection_query`` on the file. For every detection it draws
    the bounding box and a ``"label (score)"`` caption just above it.

    Args:
        filepath: Path to the image file.
        color: Colour used for box outlines and caption text (default ``"red"``).
        line_width: Box outline width in pixels (default ``2``).

    Returns:
        A new RGB ``PIL.Image.Image`` with the annotations drawn on it.

    Raises:
        ValueError: If the detection API did not return a list of detections
            (for example an error payload).
    """
    output = object_detection_query(filepath)
    if not isinstance(output, list):
        raise ValueError(f"Unexpected object detection response: {output!r}")

    # Use a context manager so the underlying file handle is released promptly.
    # convert() returns an independent image, so it stays valid after close.
    with Image.open(filepath) as source:
        image = source.convert('RGB')

    draw = ImageDraw.Draw(image)
    for detection in output:
        box = detection['box']
        xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']

        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)

        # Put the caption 10px above the box, but clamp to the top edge so it
        # is not drawn off-canvas for boxes touching the top of the image.
        text = f"{detection['label']} ({detection['score']:.2f})"
        draw.text((xmin, max(ymin - 10, 0)), text, fill=color)

    return image

def captioning_query(filepath, *, timeout=60):
    """Generate a caption for an image via the Hugging Face BLIP endpoint.

    The raw bytes of the file are posted to the hosted
    ``Salesforce/blip-image-captioning-large`` inference API.

    Args:
        filepath: Path to the image file to caption.
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON response. Callers in this module read
        ``result[0]['generated_text']`` from it.

    Raises:
        KeyError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set in the environment.
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
    with open(filepath, "rb") as f:
        data = f.read()
    response = requests.post(api_url, headers=headers, data=data, timeout=timeout)
    # Fail loudly on HTTP errors instead of returning an error payload that
    # callers would then try to index as a list of captions.
    response.raise_for_status()
    return response.json()

class ImageCaptionTools(BaseTool):
    """LangChain tool that returns a caption for an image file path."""

    # Type annotations are required on field overrides by pydantic v2, which
    # newer LangChain releases build BaseTool on; they are harmless on v1.
    name: str = "Image_Caption_Tools"
    # NOTE(review): _run only returns the caption from captioning_query; the
    # "poem, story" wording below may oversell the tool to the agent.
    description: str = (
        "Use this tool with any given image path to receive a personalized description, poem, story, or more. "
        "Ideal for agents seeking tailored insights. "
        "Let the tool craft content based on your image for a unique perspective."
    )

    def _run(self, image_path: str) -> str:
        """Caption the image at *image_path* and return the generated text.

        Raises:
            ValueError: If the captioning API response is not a non-empty list
                (for example an error payload).
        """
        result = captioning_query(image_path)
        if not isinstance(result, list) or not result:
            raise ValueError(f"Unexpected captioning response: {result!r}")
        return result[0]['generated_text']

    async def _arun(self, query: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError(f"{self.name} does not support async")
    

class ObjectDetectionTool(BaseTool):
    """LangChain tool that lists objects detected in an image file."""

    # Type annotations are required on field overrides by pydantic v2, which
    # newer LangChain releases build BaseTool on; they are harmless on v1.
    name: str = "Object_Detection_Tool"
    description: str = (
        "Object Detection Tool: Use this tool to detect objects in an image. Provide the image path, "
        "and it will return a list of detected objects. Each element in the list is in the format: "
        "[x1, y1, x2, y2] class_name confidence_score. This tool focuses on object detection, providing "
        "locations of objects in the image. For image descriptions or other insights, explore additional tools."
    )

    def _run(self, image_path: str) -> str:
        """Detect objects in the image at *image_path*.

        Returns:
            One line per detection, formatted as
            ``"[xmin, ymin, xmax, ymax] label score"``, with the box coordinates
            truncated to ints. Every line ends with ``"\\n"``; an empty string
            means nothing was detected.

        Raises:
            ValueError: If the detection API did not return a list of detections
                (for example an error payload).
        """
        results = object_detection_query(image_path)
        if not isinstance(results, list):
            raise ValueError(f"Unexpected object detection response: {results!r}")

        lines = []
        for result in results:
            box = result['box']
            coords = ", ".join(str(int(box[key])) for key in ("xmin", "ymin", "xmax", "ymax"))
            lines.append(f"[{coords}] {result['label']} {result['score']}\n")
        # join avoids the quadratic cost of repeated string concatenation.
        return "".join(lines)

    async def _arun(self, query: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError(f"{self.name} does not support async")