File size: 1,572 Bytes
0e84795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
from fastapi import FastAPI, Request
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
# Fetch the MagicQuill model repository from the Hugging Face Hub into the
# local "models" directory before the model object is constructed.
snapshot_download(
    local_dir="models",
    repo_type="model",
    repo_id="LiuZichen/MagicQuill-models",
)

# One module-level model instance, shared by every call to guess().
# NOTE(review): presumably LLaVAModel reads the weights downloaded above — confirm.
llavaModel = LLaVAModel()

def numpy_to_tensor(numpy_array):
    """Convert an image array to a float32 tensor with a leading batch axis.

    Values are cast to float32 and divided by 255, so uint8 pixel values in
    0..255 map into [0, 1]. The result has shape ``(1, *numpy_array.shape)``.

    Args:
        numpy_array: A NumPy array (or anything ``np.asarray`` accepts) of
            pixel values.

    Returns:
        A ``torch.Tensor`` of dtype float32.
    """
    array = np.asarray(numpy_array)
    # torch.from_numpy refuses arrays with negative strides (e.g. flipped
    # views such as ``img[::-1]``); copy into C order only in that case so the
    # common path stays zero-copy until ``.float()``.
    if any(stride < 0 for stride in array.strides):
        array = array.copy()
    tensor = torch.from_numpy(array).float().unsqueeze(0) / 255.
    return tensor

def guess(original_image_tensor, add_color_image_tensor, add_edge_mask):
    """Ask the LLaVA model for its answers about the user's edits.

    Despite the ``*_tensor`` names, the arguments arrive as NumPy arrays
    (they are passed to ``numpy_to_tensor``, which uses ``torch.from_numpy``).
    Each is converted to a float tensor with a leading batch dimension before
    being handed to ``llavaModel.process``.

    Args:
        original_image_tensor: The unedited image.
        add_color_image_tensor: The image carrying the user's colour strokes.
        add_edge_mask: The edge mask image (collected in mode "L" by the UI).

    Returns:
        The model's non-empty answers joined with ", ", or "" if it returned
        none.

    Raises:
        gr.Error: If any of the three images was not provided.
    """
    inputs = {
        "Original Image": original_image_tensor,
        "Colored Image": add_color_image_tensor,
        "Edge Mask": add_edge_mask,
    }
    missing = [label for label, image in inputs.items() if image is None]
    if missing:
        # Without this guard torch.from_numpy(None) fails with an opaque
        # TypeError; gr.Error surfaces a readable message in the UI instead.
        raise gr.Error("Missing input image(s): " + ", ".join(missing))

    original_image_tensor = numpy_to_tensor(original_image_tensor)
    add_color_image_tensor = numpy_to_tensor(add_color_image_tensor)
    add_edge_mask = numpy_to_tensor(add_edge_mask)

    # process() yields three values; only the two answers are reported here.
    _description, ans1, ans2 = llavaModel.process(
        original_image_tensor, add_color_image_tensor, add_edge_mask
    )
    # Skip None/empty answers so the output never contains dangling ", ".
    return ", ".join(ans for ans in (ans1, ans2) if ans)

# Simplified Gradio interface, following the official example format.
# The app is bound to ``demo`` with its queue configured up front, so the
# object can be imported without immediately starting a server.
demo = gr.Interface(
    fn=guess,
    inputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Colored Image"),
        gr.Image(image_mode="L", label="Edge Mask"),
    ],
    outputs=gr.Textbox(label="Prediction Output"),
).queue(max_size=40, status_update_rate=0.1)

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch(max_threads=4)