# Captured from a Hugging Face Space page; the Space's status at capture time was "Runtime error".
import gradio as gr | |
from setup import setup | |
import torch | |
import gc | |
from PIL import Image | |
from transformers import AutoModel, AutoImageProcessor | |
from anime2sketch.model import Anime2Sketch | |
import spaces | |
# Hugging Face Hub repository holding the transformers port of MangaLineExtraction.
MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf"

# Project-level one-time setup from the local `setup` module. It runs before
# any model below is constructed. NOTE(review): presumably this prepares
# assets such as ./models/netG.pth — confirm in setup.py.
setup()
print("Setup finished")
class MangaLineExtractor:
    """Callable wrapper turning an illustration into line art with MangaLineExtraction-hf.

    The model and its image processor are class attributes. They are loaded
    from ``MLE_MODEL_REPO`` once, when this class body executes, and are shared
    by every instance.
    """

    # trust_remote_code=True lets transformers execute the custom model /
    # processor code published in the repository.
    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(
        MLE_MODEL_REPO, trust_remote_code=True
    )

    def __call__(self, image: Image.Image) -> Image.Image:
        """Extract line art from *image*.

        Args:
            image: The input illustration, passed unchanged to the processor.

        Returns:
            A single-channel (mode ``"L"``) PIL image of the extracted lines.
        """
        inputs = self.processor(image, return_tensors="pt")
        # Inference only: without no_grad, autograd records a graph (extra
        # memory), and calling .numpy() on an output that requires grad raises.
        with torch.no_grad():
            outputs = self.model(inputs.pixel_values)
        # NOTE(review): assumes pixel_values[0] is a 2-D (H, W) map nominally in
        # 0..255 — confirm against the model repo. Clamp before the uint8 cast
        # so out-of-range values saturate instead of wrapping around; .cpu()
        # keeps this working if the model is ever moved to a GPU.
        line_map = outputs.pixel_values[0].cpu().clamp(0, 255)
        return Image.fromarray(line_map.numpy().astype("uint8"), mode="L")
# Both models are constructed once at import time. extract() and
# convert_to_sketch() read these module globals on every request.
# Anime2Sketch runs on CPU from the generator weights below.
# NOTE(review): the weights file is presumably provided by setup() — confirm.
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")
mle_model = MangaLineExtractor()
def flush():
    """Free memory held by unreachable objects and PyTorch's CUDA cache.

    Runs a full Python garbage collection first, then asks PyTorch to release
    cached, currently unused CUDA memory.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def extract(image):
    """Return the MangaLineExtraction line art for *image* (via ``mle_model``)."""
    return mle_model(image)
def convert_to_sketch(image):
    """Return the Anime2Sketch result for *image* (via ``a2s_model.predict``)."""
    return a2s_model.predict(image)
def start(image):
    """Run both converters on one input image.

    Args:
        image: Value of the input ``gr.Image`` component, or ``None`` when no
            image has been provided. NOTE(review): presumably a numpy array
            (gradio's default ``type="numpy"``), because ``Image.fromarray`` is
            applied to it — confirm if the component configuration changes.

    Returns:
        ``[line_art, sketch]`` for the two output components, or
        ``[None, None]`` (same as ``clear()``) when there is no input.
    """
    # Clicking "Start" on an empty input sends None. Without this guard the
    # model call / Image.fromarray(None) raises and the UI shows a bare error.
    if image is None:
        return [None, None]
    return [extract(image), convert_to_sketch(Image.fromarray(image).convert("RGB"))]
def clear():
    """Reset the outputs: one ``None`` for each of the two output images."""
    return [None] * 2
def ui():
    """Build the Gradio interface for the demo.

    Layout: an input image with Start/Clear buttons, two output images
    (MangaLineExtraction and Anime2Sketch), and a set of bundled examples.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks() as blocks:
        # Keep the blank lines. Without one after the last repo bullet,
        # "Using with 🤗 transformers:" becomes a lazy continuation line and
        # renders inside the Anime2Sketch list item.
        gr.Markdown(
            """
            # Anime to Sketch

            Unofficial demo for converting illustrations into sketches.

            Original repos:
            - [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
            - [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)

            Using with 🤗 transformers:
            - [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)
            """
        )

        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)
                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Column():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)

        # cache_examples is left at gradio's default.
        gr.Examples(
            fn=start,
            examples=[
                ["./examples/0.jpg"],
                ["./examples/1.jpg"],
                ["./examples/2.jpg"],
            ],
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
            label="Examples",
        )
        gr.Markdown("Images are from nijijourney.")

        extract_btn.click(
            fn=start,
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
        )
        # Clear resets only the two output images; the input is left as-is.
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[extract_output_img, to_sketch_output_img],
        )

    return blocks
if __name__ == "__main__":
    # Build the interface and serve it when run as a script.
    demo = ui()
    demo.launch()