File size: 1,972 Bytes
c41dcc5
 
 
 
 
 
8511d74
c41dcc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8511d74
 
c41dcc5
 
 
8511d74
 
c41dcc5
b652aa3
 
c41dcc5
 
 
 
 
 
 
1476ed8
c41dcc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
from transformers import pipeline
from PIL import ImageDraw
import torch

# Open-vocabulary detector: returns boxes for whatever text labels it is given.
# Both pipelines are built once, when the module is imported.
detector = pipeline(
    task="zero-shot-object-detection",
    model="google/owlvit-base-patch32",
)
# Monocular depth model, sampled at each detection's box centre in compute_depth.
depth_estimator = pipeline(
    task="depth-estimation",
    model="Intel/dpt-large",
)

def visualize_preds(image, predictions, *, color="red", width=1):
    """Return a copy of *image* with every detection's box and caption drawn.

    Args:
        image: PIL image the detections were computed on; it is not modified.
        predictions: iterable of dicts with a ``"box"`` dict (keys ``"xmin"``,
            ``"ymin"``, ``"xmax"``, ``"ymax"``), a ``"label"`` and a ``"score"``,
            as produced by the zero-shot object-detection pipeline.
        color: outline colour of the boxes (keyword-only).
        width: outline width in pixels (keyword-only).

    Returns:
        A new image carrying the rectangles and ``"label: score"`` captions.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    for prediction in predictions:
        box = prediction["box"]
        # Read the corners by key: unpacking ``box.values()`` silently depends
        # on the dict's insertion order and breaks if it ever changes.
        xmin, ymin = box["xmin"], box["ymin"]
        xmax, ymax = box["xmax"], box["ymax"]
        draw.rectangle((xmin, ymin, xmax, ymax), outline=color, width=width)
        caption = f"{prediction['label']}: {round(prediction['score'], 2)}"
        draw.text((xmin, ymin), caption, fill="white")

    return annotated

def compute_depth(image, preds, estimator=None):
    """Sample the estimated depth map at the centre of each detected box.

    Args:
        image: PIL-style image; ``image.size`` is ``(width, height)``.
        preds: detections, each a dict with a ``"label"`` and a ``"box"`` dict
            holding ``"xmin"``, ``"ymin"``, ``"xmax"``, ``"ymax"``.
        estimator: callable returning a dict with a ``"predicted_depth"``
            tensor; defaults to the module-level ``depth_estimator`` pipeline.

    Returns:
        A list of ``{"class": label, "distance": float}`` dicts, one per
        detection and in the same order.
        NOTE(review): DPT-style models predict *relative* (inverse) depth, so
        "distance" is presumably not a metric distance — confirm before use.
    """
    if estimator is None:
        estimator = depth_estimator

    raw_depth = estimator(image)["predicted_depth"]
    width, height = image.size
    # The raw map may be (1, H', W') or (H', W') depending on the transformers
    # version; normalise to the (N=1, C=1, H', W') layout interpolate expects.
    raw_depth = raw_depth.reshape(1, 1, *raw_depth.shape[-2:])
    depth_map = (
        torch.nn.functional.interpolate(
            raw_depth,
            size=(height, width),
            mode="bicubic",
            align_corners=False,
        )[0, 0]  # explicit indexing, unlike squeeze(), keeps 1-pixel dims
        .detach()
        .cpu()
        .numpy()
    )

    results = []
    for pred in preds:
        box = pred["box"]
        # Clamp the centre into the image: boxes may extend past the border,
        # and a negative index would otherwise silently wrap around.
        x = min(max(int((box["xmin"] + box["xmax"]) // 2), 0), width - 1)
        y = min(max(int((box["ymin"] + box["ymax"]) // 2), 0), height - 1)
        results.append({"class": pred["label"], "distance": float(depth_map[y, x])})

    return results

def process(image, text):
    """Detect the requested objects in *image* and estimate their depth.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.
        text: object names separated by periods, e.g. ``"cat. remote control"``.

    Returns:
        ``[annotated_image, depth_results]``, matching the two UI outputs.
        When there is no image or no usable label, no model is run and
        ``[image, []]`` is returned.
    """
    # Strip whitespace and drop empty pieces so "cat. dog." yields
    # ["cat", "dog"] instead of ["cat", " dog", ""].
    labels = [part.strip() for part in (text or "").split(".") if part.strip()]
    if image is None or not labels:
        return [image, []]

    preds = detector(image, candidate_labels=labels)
    return [visualize_preds(image, preds), compute_depth(image, preds)]

# Two-column UI: inputs on the left, annotated image and per-object depth on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Input image")
            # process() splits this text on "." to get the candidate labels.
            labels_box = gr.Textbox(
                label="Objects to find",
                placeholder="Separate object names with '.', e.g. cat. remote control",
            )
            detect_btn = gr.Button("Detect")
        with gr.Column(scale=1):
            output_detection = gr.Image(type="pil", label="Detections")
            output_distance = gr.JSON(label="Distance")

    # api_name="process" also exposes the handler as a named API endpoint.
    detect_btn.click(
        fn=process,
        inputs=[image, labels_box],
        outputs=[output_detection, output_distance],
        api_name="process",
    )

demo.launch()