Spaces:
Potre1qw
/
Running on Zero

File size: 1,681 Bytes
501a105
4eda14d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910380b
4eda14d
501a105
4eda14d
 
 
 
 
 
 
 
 
 
 
 
7542834
 
4eda14d
7542834
 
4eda14d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import spaces
import gradio as gr
import numpy as np
import random
import time
import json
import os
from loguru import logger
from decouple import config
import io
import torch
import numpy as np
import torch
import cv2
from PIL import Image

from segment_anything import sam_model_registry, SamPredictor



# Startup diagnostics. torch.cuda.get_device_name() raises when no CUDA device
# is visible, so only query the device name when CUDA is actually available.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")
if cuda_available:
    cuda_device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"CUDA device: {cuda_device_name}")
print(torch.version.cuda)

# SAM-HQ ViT-H checkpoint, built once at import time; generate_image moves the
# model to `device` on each request.
sam_checkpoint = "sam_hq_vit_h.pth"
model_type = "vit_h"
# Keep "cuda" whenever it is available (unchanged behavior); otherwise fall
# back to CPU so the app can still start and run, slowly, without a GPU.
device = "cuda" if cuda_available else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)


def _parse_point_prompt(prompt):
    """Parse the JSON point prompt into arrays for ``SamPredictor.predict``.

    ``prompt`` must be a JSON object with ``input_points`` (a list of ``[x, y]``
    pixel coordinates) and ``input_labels`` (one integer label per point).

    Returns:
        ``(points, labels)`` where ``points`` has shape ``(N, 2)`` and
        ``labels`` has shape ``(N,)``, with ``N >= 1``.

    Raises:
        gr.Error: if the prompt is not valid JSON, lacks a required key, has
            malformed points/labels, has no points, or has a different number
            of points and labels.
    """
    try:
        data = json.loads(prompt)
        points = np.asarray(data['input_points'], dtype=float).reshape(-1, 2)
        labels = np.asarray(data['input_labels'], dtype=np.int64).reshape(-1)
    except (TypeError, ValueError, KeyError) as err:
        # TypeError: prompt is None / not a JSON object; ValueError: bad JSON or
        # ragged / non-numeric arrays; KeyError: a required key is missing.
        raise gr.Error(f"Invalid prompt: {err!r}") from err
    if points.shape[0] == 0:
        raise gr.Error("Prompt must contain at least one input point.")
    if points.shape[0] != labels.shape[0]:
        raise gr.Error(
            f"Got {points.shape[0]} input points but {labels.shape[0]} input labels."
        )
    return points, labels


def _mask_to_rgb(mask):
    """Render a boolean ``(H, W)`` mask as a white-on-black RGB PIL image."""
    rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    rgb[mask] = 255
    return Image.fromarray(rgb)


@spaces.GPU(duration=10)
def generate_image(prompt, image):
    """Segment ``image`` with SAM-HQ using point prompts and return the mask.

    Args:
        prompt: JSON string with ``input_points`` and ``input_labels`` (see
            ``_parse_point_prompt``).
        image: RGB image as a numpy array, as delivered by the
            ``gr.Image(type="numpy")`` input.

    Returns:
        A PIL RGB image that is white where the predicted mask is set and black
        elsewhere.

    Raises:
        gr.Error: if no image is given or the prompt is invalid.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    # Validate the prompt before set_image(): computing the image embedding is
    # the expensive step, so reject bad input first.
    input_points, input_labels = _parse_point_prompt(prompt)

    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )
    # multimask_output=False yields a single mask, so take the first one.
    return _mask_to_rgb(masks[0])


if __name__ == "__main__":
    # Text box for the JSON point prompt plus an RGB image; the output is the
    # rendered mask image returned by generate_image.
    demo = gr.Interface(
        fn=generate_image,
        inputs=[
            "text",
            gr.Image(image_mode='RGB', type="numpy"),
        ],
        outputs=[
            gr.Image(type="numpy", image_mode='RGB'),
        ],
    )
    # launch(debug=True) blocks the main thread until the server stops, so log
    # before calling it; a log line placed after it would only run at shutdown.
    logger.debug('demo.launch()')
    demo.launch(debug=True)