# backend.py
import logging
import os

import numpy as np
import torch
from PIL import Image
from samgeo.text_sam import LangSAM
from torchvision.transforms import v2
from transformers import SamModel, SamProcessor


preproc = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
])


# Load the necessary models.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_FILE = os.getenv("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.ckpt")

# Use the processor that matches the vit-large backbone the finetuned
# checkpoint is loaded into.
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
tuned_model = SamModel.from_pretrained("facebook/sam-vit-large").to(device)
tuned_model.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=device))
langsam_model = LangSAM("vit_l")


def process_image(image: Image.Image, bbox: list[float] | None = None) -> Image.Image:
    logging.info("Processing image.")
    if bbox is None:
        # No bbox supplied: fall back to the image's own bounding box, which
        # excludes any all-zero (black) border.
        logging.debug("No bounding box supplied; using image.getbbox().")
        bbox = list(map(float, image.getbbox()))
    inputs = processor(preproc(image), input_boxes=[[bbox]],
                       do_rescale=False, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move every tensor to our device.

    mask = get_sidewalk_mask(tuned_model, inputs)
    # TODO: also fetch tree masks via get_tree_masks() and union them with
    # the sidewalk mask (see the union_masks sketch below).
    return mask


def get_sidewalk_mask(model, inputs) -> Image.Image:
    logging.info("Calculating sidewalk mask.")
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    # Apply a sigmoid to turn the predicted mask logits into probabilities.
    mask_probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # Convert to numpy for the post-processing below.
    mask_probabilities = mask_probabilities.cpu().numpy().squeeze()

    # Zero out low-confidence pixels.
    mask_probabilities[mask_probabilities < 0.5] = 0

    # Map the remaining probabilities linearly to greyscale intensity.
    greyscale_img = Image.fromarray((mask_probabilities * 255).astype(np.uint8), mode='L')
    return greyscale_img


def get_tree_masks(image: Image.Image):
    # Return LangSAM's predictions for the "tree" prompt so the caller can
    # combine them with the sidewalk mask.
    return langsam_model.predict(image, "tree", box_threshold=0.24, text_threshold=0.24)
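

# A minimal sketch (an assumption, not part of the original pipeline) of the
# mask union that process_image's TODO refers to. It assumes tree_masks is an
# iterable of HxW arrays aligned with the sidewalk mask; union_masks is a
# hypothetical helper name.
def union_masks(sidewalk_mask: Image.Image, tree_masks) -> Image.Image:
    combined = np.array(sidewalk_mask, dtype=np.uint8)
    for tree_mask in tree_masks:
        # Normalise each tree mask to a 0/255 uint8 array before combining.
        tree_arr = (np.asarray(tree_mask) > 0).astype(np.uint8) * 255
        combined = np.maximum(combined, tree_arr)  # Pixel-wise union.
    return Image.fromarray(combined, mode='L')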
    

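# Minimal usage sketch, assuming a local test image at the hypothetical path
# "data/streetview.png". It runs the sidewalk pipeline end to end and writes
# the greyscale mask next to the input.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    input_image = Image.open("data/streetview.png").convert("RGB")
    sidewalk_mask = process_image(input_image)
    sidewalk_mask.save("data/streetview_sidewalk_mask.png")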