|
import gradio as gr |
|
import numpy as np |
|
import random |
|
import time |
|
import json |
|
import os |
|
from loguru import logger |
|
from decouple import config |
|
import io |
|
import torch |
|
import numpy as np |
|
import torch |
|
import cv2 |
|
from PIL import Image |
|
|
|
from segment_anything_hq import sam_model_registry, SamPredictor |
|
|
|
import spaces |
|
|
|
# --- Runtime diagnostics ----------------------------------------------------
# Print the CUDA situation at start-up so the logs show what the model will
# run on.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    # Only query the device when CUDA is usable: current_device() raises on a
    # host without a working CUDA runtime, which used to abort the whole
    # script inside a diagnostic print.
    _cuda_device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"CUDA device: {_cuda_device_name}")
    print(torch.version.cuda)
    print(_cuda_device_name)
else:
    print(torch.version.cuda)

# --- Model setup ------------------------------------------------------------
# Checkpoint file and matching registry key for the HQ-SAM model.
sam_checkpoint = "sam_hq_vit_h.pth"
model_type = "vit_h"
# Torch device string the model is moved to; falls back to the CPU when CUDA
# is not available instead of failing in `sam.to(...)`.
device = "cuda" if _cuda_available else "cpu"

# Build the model from the checkpoint, move it to the target device and wrap
# it in the predictor that `generate_image` uses.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
|
|
|
def _parse_point_prompt(prompt):
    """Decode the JSON prompt into ``(points, labels)`` NumPy arrays.

    The prompt must be a JSON object with the keys ``input_points`` (a list
    of ``[x, y]`` pairs) and ``input_labels`` (one number per point).

    Raises:
        gr.Error: if the prompt is not valid JSON, is not an object, lacks a
            required key, or the arrays do not have matching ``(N, 2)`` /
            ``(N,)`` shapes.
    """
    try:
        payload = json.loads(prompt)
    except (TypeError, ValueError) as err:  # JSONDecodeError is a ValueError
        raise gr.Error(f"Prompt is not valid JSON: {err}") from err

    if not isinstance(payload, dict):
        raise gr.Error(
            "Prompt must be a JSON object with 'input_points' and "
            "'input_labels'."
        )

    try:
        points = np.array(payload['input_points'])
        labels = np.array(payload['input_labels'])
    except KeyError as err:
        raise gr.Error(f"Prompt is missing required key {err}.") from err
    except ValueError as err:  # e.g. ragged nested lists
        raise gr.Error(f"Prompt arrays are malformed: {err}") from err

    if not (np.issubdtype(points.dtype, np.number)
            and np.issubdtype(labels.dtype, np.number)):
        raise gr.Error("'input_points' and 'input_labels' must be numeric.")
    if points.ndim != 2 or points.shape[1] != 2:
        raise gr.Error("'input_points' must be a list of [x, y] pairs.")
    if labels.shape != (points.shape[0],):
        raise gr.Error("'input_labels' must contain one label per input point.")

    return points, labels


@spaces.GPU(duration=5)
def generate_image(prompt, image):
    """Segment ``image`` from point prompts and return a black/white mask.

    Args:
        prompt: JSON string holding ``input_points`` and ``input_labels``
            (label meaning follows the predictor's convention — presumably
            1 = foreground, 0 = background; confirm against the SAM-HQ docs).
        image: the image handed over by the Gradio input component
            (configured as an RGB NumPy array in the interface below).

    Returns:
        A ``PIL.Image`` in RGB mode, the same height/width as the predicted
        mask, white where the mask is set and black elsewhere.

    Raises:
        gr.Error: if no image was supplied or the prompt is invalid.
    """
    if image is None:
        raise gr.Error("Please provide an image to segment.")

    # Validate the prompt first so a malformed request fails before any
    # model work is done.
    input_points, input_labels = _parse_point_prompt(prompt)

    predictor.set_image(image)
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )

    # With multimask_output=False only the first mask is used; paint it
    # white on a black RGB canvas of the same height/width.
    selected = mask[0]
    rgb_array = np.zeros((mask.shape[1], mask.shape[2], 3), dtype=np.uint8)
    rgb_array[selected] = 255
    return Image.fromarray(rgb_array)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: wire `generate_image` into a Gradio interface and
    # serve it.
    # Inputs: a free-text box for the JSON prompt and an RGB image delivered
    # to the callback as a NumPy array.
    prompt_input = "text"
    image_input = gr.Image(type="numpy", image_mode='RGB')
    # Output: the resulting mask shown as an RGB image.
    mask_output = gr.Image(image_mode='RGB', type="numpy")

    demo = gr.Interface(
        fn=generate_image,
        inputs=[prompt_input, image_input],
        outputs=[mask_output],
    )
    demo.launch(debug=True)
    # NOTE(review): this log line runs only after `launch` returns; with
    # debug=True that is presumably after the server stops — confirm that
    # this is the intended moment to log.
    logger.debug('demo.launch()')
|
|