# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import numpy as np | |
import random | |
import time | |
import json | |
import os | |
from loguru import logger | |
from decouple import config | |
import io | |
import torch | |
import numpy as np | |
import torch | |
import cv2 | |
from PIL import Image | |
from segment_anything import sam_model_registry, SamPredictor | |
import spaces | |
# --- Runtime diagnostics and model setup ------------------------------------
# The device-name/index queries are only issued after confirming that CUDA is
# usable; calling them unconditionally aborts the import on a GPU-less host.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"CUDA device: {gpu_name}")
print(torch.version.cuda)
if cuda_available:
    # Kept so the startup output matches the previous version line for line.
    print(gpu_name)

# Checkpoint path and registry key used to build the model.
sam_checkpoint = "sam-hq/models/sam_hq_vit_h.pth"
model_type = "vit_h"

# Single source of truth for the torch device: use the GPU when there is one,
# otherwise fall back to the CPU instead of failing in ``sam.to``.
device = "cuda" if cuda_available else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Shared predictor instance used by the request handler below.
predictor = SamPredictor(sam)
def generate_image(prompt, image):
    """Segment ``image`` from point prompts and return a black/white mask image.

    Args:
        prompt: JSON text holding an object with ``input_points`` (a list of
            ``[x, y]`` pairs) and ``input_labels`` (one label per point).
        image: The picture to segment; it is passed unchanged to
            ``predictor.set_image``.

    Returns:
        A ``PIL.Image.Image`` built from a uint8 RGB array with the mask's
        height and width: 255 where the mask selects a pixel, 0 elsewhere.

    Raises:
        ValueError: If ``image`` is ``None``, ``prompt`` is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), the JSON value is
            not an object, or points and labels do not line up.
        KeyError: If ``prompt`` lacks ``input_points`` or ``input_labels``.
    """
    if image is None:
        raise ValueError("an input image is required")

    # Parse and validate the request first so that a malformed prompt is
    # rejected before any model work is started.
    request = json.loads(prompt)
    if not isinstance(request, dict):
        raise ValueError("prompt must be a JSON object")

    input_points = np.array(request['input_points'])
    input_labels = np.array(request['input_labels'])
    if input_points.ndim != 2 or input_points.shape[1] != 2:
        raise ValueError("input_points must be a list of [x, y] pairs")
    if input_labels.shape != (input_points.shape[0],):
        raise ValueError("input_labels must contain exactly one label per point")

    predictor.set_image(image)
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )

    # The mask is indexed as (1, H, W) and mask[0] is used as a selector over
    # the pixel grid — assumes a boolean array; TODO confirm against predictor.
    rgb_array = np.zeros((mask.shape[1], mask.shape[2], 3), dtype=np.uint8)
    rgb_array[mask[0]] = 255
    return Image.fromarray(rgb_array)
if __name__ == "__main__":
    # Inputs: the JSON point prompt as free text, plus the RGB image to segment.
    ui_inputs = [
        "text",
        gr.Image(type="numpy", image_mode='RGB'),
    ]
    # Output: the rendered mask image.
    ui_outputs = [
        gr.Image(image_mode='RGB', type="numpy"),
    ]

    demo = gr.Interface(
        fn=generate_image,
        inputs=ui_inputs,
        outputs=ui_outputs,
    )
    demo.launch(debug=True)
    # NOTE(review): this runs only once launch() returns; if launch(debug=True)
    # blocks until shutdown, the message appears after the server stops —
    # confirm that is the intent.
    logger.debug('demo.launch()')