# Source: Hugging Face Spaces page (status: Running).
# backend.py | |
import numpy as np | |
from PIL import Image, ImageDraw | |
import torch | |
from transformers import SamModel, SamProcessor | |
from torchvision.transforms import v2 | |
from samgeo.text_sam import LangSAM | |
import os | |
import logging | |
# Preprocessing pipeline applied to a PIL image before it is handed to the
# SAM processor: convert to a tensor, then to float32 with value scaling
# enabled (so the processor is later called with do_rescale=False).
preproc = v2.Compose(
    [
        v2.PILToTensor(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)
# Model setup. Everything in this section runs at import time.

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the fine-tuned weights; can be overridden through the
# SAM_FINETUNE_CHECKPOINT environment variable.
CHECKPOINT_FILE = os.environ.get("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.ckpt")

# NOTE(review): the processor is created from the "sam-vit-base" config while
# the model weights below come from "sam-vit-large" -- confirm this is intended.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Start from the pretrained SAM model, move it to the target device, then
# overwrite its weights with the fine-tuned checkpoint.
tuned_model = SamModel.from_pretrained("facebook/sam-vit-large")
tuned_model = tuned_model.to(device)
# NOTE(review): torch.load unpickles the checkpoint file; only point
# CHECKPOINT_FILE at files from a trusted source.
tuned_model.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=device))

# Text-prompted segmentation model used by get_tree_masks.
langsam_model = LangSAM("vit_l")
def process_image(image: Image.Image, bbox: "list[float] | None" = None) -> Image.Image:
    """Run the fine-tuned SAM model on ``image`` and return its mask image.

    Args:
        image: PIL image to segment.
        bbox: Optional box prompt as four numbers. When omitted, the bounding
            box of the image's non-zero region (``image.getbbox()``) is used,
            i.e. ``(left, upper, right, lower)``; if the image has no
            non-zero pixels at all, the full image extent is used instead.

    Returns:
        The greyscale mask produced by ``get_sidewalk_mask``.

    Raises:
        ValueError: If ``bbox`` does not contain exactly four coordinates.
    """
    logging.info("Logging image information.")
    if bbox is None:
        # No bbox information. Use default (filters out zeroes).
        logging.debug("Using default, null bounding box.")
        bbox = image.getbbox()
        if bbox is None:
            # getbbox() returns None for an all-zero image; fall back to the
            # whole frame rather than crashing on float conversion.
            width, height = image.size
            bbox = (0, 0, width, height)

    # Normalise to a list of floats regardless of where the box came from.
    bbox = [float(coord) for coord in bbox]
    if len(bbox) != 4:
        raise ValueError(f"bbox must have exactly 4 coordinates, got {len(bbox)}")

    # The image is already scaled by `preproc`, so the processor must not
    # rescale it a second time (do_rescale=False).
    inputs = processor(preproc(image), input_boxes=[[bbox]],
                       do_rescale=False, return_tensors="pt")
    # Move every tensor in the batch onto the model's device.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    mask = get_sidewalk_mask(tuned_model, inputs)
    # TODO: get tree masks and (possibly) union them with the sidewalk mask.
    return mask
def get_sidewalk_mask(model, inputs) -> Image:
    """Predict a single mask with ``model`` and render it as a greyscale image.

    The model is switched to eval mode and called without gradient tracking,
    requesting a single mask (``multimask_output=False``). The predicted mask
    logits are passed through a sigmoid; probabilities below 0.5 are zeroed
    and the remainder are scaled linearly by 255 before conversion to an
    8-bit ('L' mode) PIL image.

    Args:
        model: Segmentation model; called as ``model(**inputs, ...)`` and
            expected to return an object with a ``pred_masks`` tensor.
        inputs: Mapping of keyword arguments forwarded to the model.

    Returns:
        The thresholded mask as a greyscale PIL image.
        NOTE(review): no resizing happens here, so the image has whatever
        spatial size the model's ``pred_masks`` output has.
    """
    logging.info("Calculating mask.")
    model.eval()
    with torch.no_grad():
        prediction = model(**inputs, multimask_output=False)

    # Logits -> probabilities, then over to numpy with singleton axes removed.
    mask_logits = prediction.pred_masks.squeeze(1)
    probs = torch.sigmoid(mask_logits).cpu().numpy().squeeze()

    # Suppress low-confidence pixels entirely.
    low_confidence = probs < 0.5
    probs[low_confidence] = 0

    # Linear map from probability to pixel intensity.
    probs *= 255
    return Image.fromarray(probs).convert('L')
def get_tree_masks(image: Image.Image):
    """Run the text-prompted LangSAM model on ``image`` with the prompt "tree".

    Both the box and text thresholds are fixed at 0.24.

    Args:
        image: PIL image to search for trees.

    Returns:
        Whatever ``langsam_model.predict`` returns for this call. Previously
        the result was discarded and this function always returned ``None``.
        NOTE(review): depending on the installed samgeo version, ``predict``
        may itself return ``None`` and expose its results on the model
        object instead -- confirm against the samgeo API in use.
    """
    return langsam_model.predict(image, "tree", box_threshold=0.24, text_threshold=0.24)
# masks, boxes, phrases, logits = tuned_model.predict(image_pil, bbox) | |
# tree_data = langsam_model.predict(image_pil, text_prompt) | |
# def draw_layer_on_image(model, im: Image, text_prompt: str='sidewalk') -> Image: | |