RingMo-SAM / app.py
AI-Cyber's picture
Update app.py
65b7b9e verified
raw
history blame
21.3 kB
# -*- coding: utf-8 -*-
import sys
import io
import requests
import json
import base64
from PIL import Image
import numpy as np
import gradio as gr
import mmengine
from mmengine import Config, get
import argparse
import os
import cv2
import yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import datasets
import models
import numpy as np
from torchvision import transforms
from mmcv.runner import load_checkpoint
import visual_utils
from PIL import Image
from models.utils_prompt import get_prompt_inp, pre_prompt, pre_scatter_prompt, get_prompt_inp_scatter
# Intended inference device for this CPU-only demo.
# NOTE(review): this name is not referenced by the code visible in this file;
# the models are moved with `.cpu()` and checkpoints are loaded with
# `map_location='cpu'` directly, so this line only documents intent.
device = torch.device("cpu")
def batched_predict(model, inp, coord, bsize):
    """Query ``model`` at ``coord`` in chunks of at most ``bsize`` points.

    ``model.gen_feat(inp)`` is called once, then ``model.query_rgb`` is called on
    consecutive slices of ``coord`` along dim 1. The per-chunk results are
    concatenated along dim 1. Everything runs under ``torch.no_grad()``.

    Args:
        model: object exposing ``gen_feat(inp)`` and ``query_rgb(coord_chunk)``.
        inp: passed unchanged to ``model.gen_feat``.
        coord: tensor sliced as ``coord[:, a:b, :]``; dim 1 indexes query points.
        bsize: maximum number of points per ``query_rgb`` call.

    Returns:
        ``(pred, preds)``: the concatenated prediction and the list of per-chunk
        predictions.

    Raises:
        ValueError: if ``bsize`` is not positive. The previous ``while`` loop
            never advanced in that case and spun forever.
    """
    if bsize <= 0:
        raise ValueError(f"bsize must be a positive integer, got {bsize!r}")
    n = coord.shape[1]
    with torch.no_grad():
        model.gen_feat(inp)
        # Slicing past the end is clipped, so the last chunk may be shorter.
        preds = [model.query_rgb(coord[:, start:start + bsize, :])
                 for start in range(0, n, bsize)]
        pred = torch.cat(preds, dim=1)
    return pred, preds
def tensor2PIL(tensor):
    """Return ``tensor`` converted to a PIL image using torchvision's ``ToPILImage``."""
    return transforms.ToPILImage()(tensor)
def Decoder1_optical_instance(image_input):
    """Segment instance-type objects in an optical image with the instance decoder.

    The model is rebuilt from ``configs/fine_tuning_one_decoder.yaml`` and
    ``./save/model_epoch_last.pth`` on every call. The image is resized to
    1024 x 1024 and normalised with ImageNet mean/std. No prompts are given to
    the prompt encoder.

    Args:
        image_input: PIL image from ``gr.Image(type='pil')``.

    Returns:
        The colour-coded label map: element ``[0]`` of ``Label2Color`` output with
        the ``'Unify_double'`` colour map.

    Raises:
        ValueError: if no image was supplied.
    """
    if image_input is None:
        raise ValueError("Please provide an input image.")

    with open('configs/fine_tuning_one_decoder.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = models.make(config['model']).cpu()
    sam_checkpoint = torch.load("./save/model_epoch_last.pth", map_location='cpu')
    # strict=False: checkpoint keys that do not match the model are ignored.
    model.load_state_dict(sam_checkpoint, strict=False)
    model.eval()
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_double'))

    img = transforms.Resize([1024, 1024])(image_input)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_img = transform(img).unsqueeze(0)

    # Inference only: without no_grad autograd would keep every activation of
    # the encoder/decoder alive for a backward pass that never happens.
    with torch.no_grad():
        image_embedding = model.image_encoder(input_img)  # original note: torch.Size([1, 256, 64, 64])
        sparse_embeddings, dense_embeddings, _ = model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
            scatter=None)
        # Instance-category (target class) decoder.
        low_res_masks, _ = model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        pred = model.postprocess_masks(low_res_masks, model.inp_size, model.inp_size)

    _, prediction = pred.max(dim=1)
    return label2color(prediction.cpu().numpy().astype(np.uint8))[0]
# Sums of (terrain label + instance label) that the fusion step accepts as-is.
# NOTE(review): hard-coded from the original implementation; which label pairs
# they stand for depends on the two decoders' label maps (not visible here).
_TERRAIN_SETTLED_SUMS = (1, 2, 5, 6, 14)
# Terrain-decoder class index remapped to 0 before fusion (original comment:
# "turn the second decoder's background into 0").
_TERRAIN_DECODER_BACKGROUND = 12
# The fused map has 15 channels. Only these channels are filled (from the
# terrain and instance decoder respectively); all others stay 0.
_FUSED_NUM_CLASSES = 15
_FUSED_FROM_TERRAIN = (0, 1, 2, 6)
_FUSED_FROM_INSTANCE = (5, 14)


def _fuse_terrain_instance(pred, pred_instance):
    """Merge terrain- and instance-decoder probabilities into one label map.

    Args:
        pred: softmax probabilities from ``mask_decoder_diwu``, laid out as
            ``(batch, class, H, W)``.
        pred_instance: softmax probabilities from ``mask_decoder``, same layout,
            with at least 15 classes.

    Returns:
        Integer ``(batch, H, W)`` tensor. Each decoder's argmax label is taken,
        after mapping the terrain label 12 to 0, and the two labels are summed.
        Where that sum is in ``_TERRAIN_SETTLED_SUMS`` the sum is the output.
        Elsewhere the output is the argmax over a 15-channel map holding terrain
        channels 0, 1, 2, 6 and instance channels 5, 14 (all other channels 0).
    """
    _, terrain_label = pred.max(dim=1)
    terrain_label[terrain_label == _TERRAIN_DECODER_BACKGROUND] = 0
    _, instance_label = pred_instance.max(dim=1)
    # Where the decoders do not conflict the plain sum is the answer.
    label_sum = terrain_label + instance_label

    settled = torch.zeros_like(label_sum, dtype=torch.bool)
    for value in _TERRAIN_SETTLED_SUMS:
        settled |= label_sum == value

    fused = pred.new_zeros((pred.shape[0], _FUSED_NUM_CLASSES) + tuple(pred.shape[2:]))
    for channel in _FUSED_FROM_TERRAIN:
        fused[:, channel] = pred[:, channel]
    for channel in _FUSED_FROM_INSTANCE:
        fused[:, channel] = pred_instance[:, channel]
    conflict_label = torch.argmax(fused, dim=1)

    return torch.where(settled, label_sum, conflict_label)


def Decoder1_optical_terrain(image_input):
    """Segment an optical image by fusing the instance and terrain decoders.

    The model is rebuilt from ``configs/fine_tuning_one_decoder.yaml`` and
    ``./save/model_epoch_last.pth`` on every call. The image is resized to
    1024 x 1024 and normalised with ImageNet mean/std. Both decoders run on the
    same embedding and their softmax outputs are merged by
    ``_fuse_terrain_instance``.

    Args:
        image_input: PIL image from ``gr.Image(type='pil')``.

    Returns:
        The colour-coded label map: element ``[0]`` of ``Label2Color`` output with
        the ``'Unify_Vai'`` colour map.

    Raises:
        ValueError: if no image was supplied.
    """
    if image_input is None:
        raise ValueError("Please provide an input image.")

    with open('configs/fine_tuning_one_decoder.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = models.make(config['model']).cpu()
    sam_checkpoint = torch.load("./save/model_epoch_last.pth", map_location='cpu')
    # strict=False: checkpoint keys that do not match the model are ignored.
    model.load_state_dict(sam_checkpoint, strict=False)
    model.eval()
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_Vai'))

    img = transforms.Resize([1024, 1024])(image_input)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_img = transform(img).unsqueeze(0)

    # Inference only: avoid recording an autograd graph.
    with torch.no_grad():
        image_embedding = model.image_encoder(input_img)  # original note: torch.Size([1, 256, 64, 64])
        sparse_embeddings, dense_embeddings, _ = model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
            scatter=None)
        # Instance-category decoder.
        low_res_masks_instance, _ = model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Terrain-category decoder (original note on output layout: B*C+1*H*W).
        low_res_masks, _ = model.mask_decoder_diwu(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        pred_instance = model.postprocess_masks(low_res_masks_instance, model.inp_size, model.inp_size)
        pred = model.postprocess_masks(low_res_masks, model.inp_size, model.inp_size)
        pred = torch.softmax(pred, dim=1)
        pred_instance = torch.softmax(pred_instance, dim=1)
        prediction_final = _fuse_terrain_instance(pred, pred_instance)

    return label2color(prediction_final.cpu().numpy())[0]
# Thickness, in pixels, of the red frame drawn inside each box. The original
# drew pixels inside the outer box but outside it shrunk by (radius 4 - 1).
_BOX_BORDER = 3


def _sketch_to_boxes(sketch, size):
    """Return the bounding box of every 8-connected stroke in a sketch mask.

    Args:
        sketch: PIL image from ``gr.ImageMask`` (``input_prompt["mask"]``), or None.
        size: ``(width, height)`` of the canvas the boxes will be drawn on. The
            sketch is resized to this size (nearest neighbour) so box
            coordinates line up with the canvas.

    Returns:
        List of inclusive boxes ``((row_min, col_min), (row_max, col_max))``.
        The list is empty when there is no sketch or nothing was drawn.
    """
    if sketch is None:
        return []
    strokes = np.array(sketch.resize(size, Image.NEAREST))
    if strokes.ndim == 3:
        # Multi-channel sketch: a pixel counts as drawn if any channel is non-zero.
        strokes = strokes.sum(-1)
    binary = np.uint8(strokes != 0)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    boxes = []
    for label in range(1, num_labels):  # label 0 is the undrawn (zero) background
        top = int(stats[label, cv2.CC_STAT_TOP])
        left = int(stats[label, cv2.CC_STAT_LEFT])
        height = int(stats[label, cv2.CC_STAT_HEIGHT])
        width = int(stats[label, cv2.CC_STAT_WIDTH])
        boxes.append(((top, left), (top + height - 1, left + width - 1)))
    return boxes


def _draw_box_frames(canvas, boxes, border=_BOX_BORDER):
    """Paint a red frame ``border`` pixels thick just inside each box, in place.

    Args:
        canvas: ``H x W x 3`` array, modified in place.
        boxes: inclusive ``((row_min, col_min), (row_max, col_max))`` boxes.
        border: frame thickness in pixels.

    Returns:
        Boolean ``H x W`` array that is True for pixels inside at least one box.
    """
    rows = np.arange(canvas.shape[0])[:, None]
    cols = np.arange(canvas.shape[1])[None, :]
    inside_any = np.zeros(canvas.shape[:2], dtype=bool)
    for (r0, c0), (r1, c1) in boxes:
        outer = (rows >= r0) & (rows <= r1) & (cols >= c0) & (cols <= c1)
        inner = ((rows >= r0 + border) & (rows <= r1 - border)
                 & (cols >= c0 + border) & (cols <= c1 - border))
        canvas[outer & ~inner] = (255, 0, 0)
        inside_any |= outer
    return inside_any


def Multi_box_prompts(input_prompt):
    """Run the instance decoder and keep only the regions inside drawn boxes.

    Each connected stroke drawn on the ``gr.ImageMask`` input becomes its
    bounding box. Each box is outlined in red on the colour-coded prediction,
    and everything outside all boxes is zeroed. The strokes are not passed to
    the prompt encoder (``boxes=None``).

    Args:
        input_prompt: dict from ``gr.ImageMask(type='pil')`` with key ``"image"``
            (PIL image) and key ``"mask"`` (PIL sketch; treated as "no boxes" if
            absent or None).

    Returns:
        Colour-coded prediction (``'Unify_double'`` colour map) masked to the
        drawn boxes. If no box was drawn, the full prediction is returned.
        Previously that case raised ``UnboundLocalError``.

    Raises:
        ValueError: if no input image was supplied.
    """
    if not input_prompt or input_prompt.get("image") is None:
        raise ValueError("Please provide an input image.")

    with open('configs/fine_tuning_one_decoder.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = models.make(config['model']).cpu()
    sam_checkpoint = torch.load("./save/model_epoch_last.pth", map_location='cpu')
    # strict=False: checkpoint keys that do not match the model are ignored.
    model.load_state_dict(sam_checkpoint, strict=False)
    model.eval()
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_double'))

    img = transforms.Resize([1024, 1024])(input_prompt["image"])
    # Same ImageNet normalisation as Decoder1_optical_instance, which runs this
    # same config/checkpoint and decoder. The input here was previously
    # un-normalised.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_img = transform(img).unsqueeze(0)

    # Inference only: avoid recording an autograd graph.
    with torch.no_grad():
        image_embedding = model.image_encoder(input_img)  # original note: torch.Size([1, 256, 64, 64])
        sparse_embeddings, dense_embeddings, _ = model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
            scatter=None)
        # Instance-category (target class) decoder.
        low_res_masks, _ = model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        pred = model.postprocess_masks(low_res_masks, model.inp_size, model.inp_size)

    _, prediction = pred.max(dim=1)
    prediction_to_save = label2color(prediction.cpu().numpy().astype(np.uint8))[0]

    # Boxes must be in the coordinate frame of the canvas they are drawn on;
    # PIL sizes are (width, height).
    canvas_size = (prediction_to_save.shape[1], prediction_to_save.shape[0])
    boxes = _sketch_to_boxes(input_prompt.get("mask"), canvas_size)
    if not boxes:
        return prediction_to_save
    inside_any = _draw_box_frames(prediction_to_save, boxes)
    # Broadcast the per-pixel mask over the colour channels.
    return prediction_to_save * inside_any[..., None]
def Decoder2_SAR(SAR_image, SAR_prompt):
    """Segment a SAR image, conditioning on its polarization scatter prompt.

    The model is rebuilt from ``configs/multi_mo_multi_task_sar_prompt.yaml``
    and ``./save/SAR/model_epoch_last.pth`` (strict key matching) on every call.

    Args:
        SAR_image: PIL image from ``gr.Image(type='pil')``.
        SAR_prompt: path to the scatter-prompt image from
            ``gr.Image(type='filepath')``, read with ``cv2.IMREAD_UNCHANGED``.

    Returns:
        The colour-coded label map: element ``[0]`` of ``Label2Color`` output with
        the ``'Unify_YIJISAR'`` colour map.

    Raises:
        ValueError: if the SAR image or prompt is missing, or if the prompt file
            cannot be read. ``cv2.imread`` returns None instead of raising.
    """
    if SAR_image is None:
        raise ValueError("Please provide a SAR image.")
    if not SAR_prompt:
        raise ValueError("Please provide a SAR polarization scatter prompt.")
    # Read and validate the prompt before the expensive model load.
    scatter_map = cv2.imread(SAR_prompt, cv2.IMREAD_UNCHANGED)
    if scatter_map is None:
        raise ValueError(f"Could not read SAR scatter prompt image: {SAR_prompt!r}")

    with open('configs/multi_mo_multi_task_sar_prompt.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = models.make(config['model']).cpu()
    sam_checkpoint = torch.load("./save/SAR/model_epoch_last.pth", map_location='cpu')
    model.load_state_dict(sam_checkpoint, strict=True)
    model.eval()
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_YIJISAR'))

    img = transforms.Resize([1024, 1024])(SAR_image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_img = transform(img).unsqueeze(0)
    # Passed to pre_scatter_prompt unchanged (a float tensor holding 0.0);
    # presumably means "not flipped" -- TODO confirm against models.utils_prompt.
    flip_flag = torch.Tensor([False])

    # Inference only: avoid recording an autograd graph.
    with torch.no_grad():
        image_embedding = model.image_encoder(input_img)
        scatter_torch = pre_scatter_prompt(scatter_map, flip_flag, device=input_img.device)
        # Add a leading dim, then resize the prompt to 256 x 256 with
        # F.interpolate's default mode.
        scatter_torch = torch.nn.functional.interpolate(scatter_torch.unsqueeze(0), size=(256, 256))
        sparse_embeddings, dense_embeddings, _ = model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
            scatter=scatter_torch)
        # Terrain-category decoder (original note on output layout: B*C+1*H*W).
        low_res_masks, _ = model.mask_decoder_diwu(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        pred = model.postprocess_masks(low_res_masks, model.inp_size, model.inp_size)

    _, prediction = pred.max(dim=1)
    return label2color(prediction.cpu().numpy())[0]
# Gradio example rows for each tab: one inner list per example, holding one
# file path per Interface input.
_ISAID_DIR = './images/optical/isaid'
_VAIHINGEN_DIR = './images/optical/vaihingen'
_SAR_DIR = './images/sar'

# iSAID optical tiles (used by the instance and multi-box tabs).
_ISAID_TILES = (
    '_P0007_1065_319',
    '_P0466_1068_420',
    '_P0897_146_34',
    '_P1397_844_904',
    '_P2645_883_965',
    '_P1398_1290_630',
)
# Vaihingen optical tiles (used by the terrain tab).
_VAIHINGEN_TILES = (
    'top_mosaic_09cm_area2_105',
    'top_mosaic_09cm_area4_227',
    'top_mosaic_09cm_area20_142',
    'top_mosaic_09cm_area24_128',
    'top_mosaic_09cm_area27_34',
)
# SAR tiles: each stem has an image "<stem>_image.png" and a scatter prompt "<stem>.png".
_SAR_SCENE_1 = 'YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl'
_SAR_SCENE_2 = 'YIJISARGF3_MYN_QPSI_999996_E121.2_N30.3_20160815_L1A_L10002015572_ampl'
_SAR_STEMS = (
    f'{_SAR_SCENE_1}_4',
    f'{_SAR_SCENE_1}_15',
    f'{_SAR_SCENE_1}_24',
    f'{_SAR_SCENE_1}_41',
    f'{_SAR_SCENE_2}_150',
)

examples1_instance = [[f'{_ISAID_DIR}/{tile}_image.png'] for tile in _ISAID_TILES]
examples1_terrain = [[f'{_VAIHINGEN_DIR}/{tile}_image.png'] for tile in _VAIHINGEN_TILES]
# Same tiles as the instance tab, but independent list objects.
examples1_multi_box = [[f'{_ISAID_DIR}/{tile}_image.png'] for tile in _ISAID_TILES]
examples2 = [[f'{_SAR_DIR}/{stem}_image.png', f'{_SAR_DIR}/{stem}.png'] for stem in _SAR_STEMS]
# RingMo-SAM designs two new promptable forms based on the characteristics of
# multimodal remote sensing images: multi-boxes prompt and SAR polarization
# scatter prompt.
# HTML header for the tabbed demo (passed as `title=` to gr.TabbedInterface).
# Every <h2> is now closed with </h2>; previously one was left open and the
# other was "closed" with a second <h2>. The non-standard <image> tag is <img>.
title = (
    "RingMo-SAM:A Foundation Model for Segment Anything in Multimodal Remote Sensing Images<br> "
    "<div align='center'> "
    "<h2><a href='https://ieeexplore.ieee.org/document/10315957' target='_blank' rel='noopener'>[paper]</a></h2> "
    "<br> "
    "<img src='file/RingMo-SAM.gif' width='720px' /> "
    "<h2>RingMo-SAM can not only segment anything in optical and SAR remote sensing data, but also identify object categories.</h2> "
    "</div> "
)
# NOTE(review): the four commented-out lines below are inactive leftovers: a
# "[code]" link fragment for the title HTML and an earlier gr.Blocks layout.
# <a href='https://github.com/AICyberTeam' target='_blank' rel='noopener'>[code]</a></h2> \
# with gr.Blocks() as demo:
# image_input = gr.Image(type='pil', label='Input Img')
# image_output = gr.Image(label='Segment Result', type='numpy')
# Tab "optical_instance_img": runs Decoder1_optical_instance on one uploaded
# optical image (delivered as a PIL image) and shows the returned numpy array.
# The description string uses backslash-newline continuations, so its source
# lines are joined into one line of HTML.
# NOTE(review): allow_flagging='auto' -- per the Gradio docs this saves every
# submission automatically; confirm that is intended for a public demo.
Decoder_optical_instance_io = gr.Interface(fn=Decoder1_optical_instance,
inputs=[gr.Image(type='pil', label='optical_instance_img(光学图像)')],
outputs=[gr.Image(label='segment_result', type='numpy')],
# title=title,  (the title is set once on the TabbedInterface instead)
description="<p> \
Instance_Decoder:<br>\
Instance-type objects (such as vehicle, aircraft, ship, etc.) have a smaller proportion. <br>\
Our decoder can decouple the SAM's mask decoder into instance category decoder and terrain category decoder to ensure that the model fits adequately to both types of data. <br>\
Choose an example below, or, upload optical instance images to be tested. <br>\
Examples below were never trained and are randomly selected for testing in the wild. <br>\
</p>",
allow_flagging='auto',
examples=examples1_instance,
cache_examples=False,
)
# Tab "optical_terrain_img": runs Decoder1_optical_terrain (instance + terrain
# decoder fusion) on one uploaded optical image given as a PIL image.
# The description string uses backslash-newline continuations, so its source
# lines are joined into one line of HTML.
# NOTE(review): allow_flagging='auto' -- per the Gradio docs this saves every
# submission automatically; confirm that is intended for a public demo.
Decoder_optical_terrain_io = gr.Interface(fn=Decoder1_optical_terrain,
inputs=[gr.Image(type='pil', label='optical_terrain_img(光学图像)')],
# (An earlier three-input variant -- optical image, SAR image and SAR scatter
# prompt, all as PIL images -- was disabled here.)
outputs=[gr.Image(label='segment_result', type='numpy')],
# title=title,  (the title is set once on the TabbedInterface instead)
description="<p> \
Terrain_Decoder:<br>\
Terrain-type objects (such as vegetation, land, river, etc.) have a larger proportion. <br>\
Our decoder can decouple the SAM's mask decoder into instance category decoder and terrain category decoder to ensure that the model fits adequately to both types of data. <br>\
Choose an example below, or, upload optical terrain images to be tested. <br>\
Examples below were never trained and are randomly selected for testing in the wild. <br>\
</p>",
allow_flagging='auto',
examples=examples1_terrain,
cache_examples=False,
)
# Tab "multi_box_prompts": the user paints strokes with a 4-px brush on an
# ImageMask input. Multi_box_prompts receives a dict with the PIL "image" and
# the "mask" sketch.
# NOTE(review): the description says the boxes are encoded as sparse prompt
# embeddings, but Multi_box_prompts calls the prompt encoder with boxes=None
# and uses the strokes only to outline/crop the output. Reconcile the text or
# the code.
# NOTE(review): allow_flagging='auto' -- per the Gradio docs this saves every
# submission automatically; confirm that is intended for a public demo.
Decoder_multi_box_prompts_io = gr.Interface(fn=Multi_box_prompts,
inputs=[gr.ImageMask(brush_radius=4, type='pil', label='input_img(图像)')],
outputs=[gr.Image(label='segment_result', type='numpy')],
# title=title,  (the title is set once on the TabbedInterface instead)
description="<p> \
Multi-box Prompts:<br>\
Multiple boxes are sequentially encoded as concated sparse high-dimensional feature embedding, \
the corresponding multiple high-dimensional features are concated together into a high-dimensional feature vector as part of the sparse embedding. <br>\
Choose an example below, or, upload images to be tested, and then draw multi-boxes. <br>\
Examples below were never trained and are randomly selected for testing in the wild. <br>\
</p>",
allow_flagging='auto',
examples=examples1_multi_box,
cache_examples=False,
)
# Tab "SAR_img": Decoder2_SAR receives the SAR image as a PIL image and the
# scatter prompt as a file path (type='filepath'), because it reads the prompt
# with cv2.imread. Each example row in examples2 supplies both paths.
# NOTE(review): allow_flagging='auto' -- per the Gradio docs this saves every
# submission automatically; confirm that is intended for a public demo.
Decoder_SAR_io = gr.Interface(fn=Decoder2_SAR,
inputs=[gr.Image(type='pil', label='SAR_img(SAR图像)'), gr.Image(type='filepath', label='SAR_prompt(偏振散射提示)')],
outputs=[gr.Image(label='segment_result', type='numpy')],
description="<p> \
SAR Polarization Scatter Prompts:<br>\
Different terrain categories usually exhibit different scattering properties. \
Therefore, we code network for coded mapping of these SAR polarization scatter prompts to the corresponding SAR images, \
which improves the segmentation results of SAR images. <br>\
Choose an example below, or, upload SAR images and the corresponding polarization scatter prompts to be tested. <br>\
Examples below were never trained and are randomly selected for testing in the wild. <br>\
</p>",
allow_flagging='auto',
examples=examples2,
cache_examples=False,
)
# Combine the four demos into one tabbed app. `demo` is bound to the
# TabbedInterface itself; it was previously bound to the return value of
# `.launch()`, so the name did not refer to the app object.
demo = gr.TabbedInterface(
    [Decoder_optical_instance_io, Decoder_optical_terrain_io, Decoder_multi_box_prompts_io, Decoder_SAR_io],
    ['optical_instance_img(光学图像)', 'optical_terrain_img(光学图像)', 'multi_box_prompts(多框提示)', 'SAR_img(偏振散射提示)'],
    title=title,
)

# Start the server only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()
# -