File size: 1,254 Bytes
d0f5c68 ca77dd1 d0f5c68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import gradio as gr
import torch
# Available pretrained models, as exposed by the hub repo's `model_map`
# entrypoint; the model picker in the UI lists them via `list(model_map)`.
model_map = torch.hub.load('nateraw/image-generation:main', 'model_map')

# Report once at startup whether CUDA is usable in this process.
on_gpu = torch.cuda.is_available()
gpu_badge = '🟢' if on_gpu else '🔴'
print(f"GPU enabled? - {gpu_badge}")
class InferenceWrapper:
    """Hold one StyleGAN3 hub pipeline and swap it when a different model is requested.

    Attributes:
        model: *name* of the pretrained checkpoint currently loaded (not the
            model object itself).
        pipe: pipeline object returned by ``torch.hub.load`` for ``model``.
    """

    # torch.hub repository (and branch) the pipelines are loaded from.
    _HUB_REPO = 'nateraw/image-generation:main'

    def __init__(self, model):
        """Load the pipeline for the pretrained checkpoint named *model*."""
        self.pipe = self._load_pipeline(model)
        self.model = model

    def _load_pipeline(self, name):
        """Return the ``styleganv3`` hub pipeline for pretrained checkpoint *name*."""
        return torch.hub.load(self._HUB_REPO, 'styleganv3', pretrained=name)

    def __call__(self, seed, model):
        """Run the pipeline for *model* on *seed*, reloading only when the model changes.

        Args:
            seed: forwarded unchanged to the pipeline.
            model: name of the pretrained checkpoint to use.

        Returns:
            Whatever the loaded pipeline returns for ``seed``.

        Raises:
            Any exception raised while loading a new pipeline. In that case the
            previously loaded (name, pipeline) pair is left intact.
        """
        if model != self.model:
            print(f"Loading model: {model}")
            # Load first, then commit both attributes together. If the load
            # raises, the wrapper must not end up recording the new name next
            # to the old pipeline (which would later be "reused" for it).
            pipe = self._load_pipeline(model)
            self.model, self.pipe = model, pipe
        else:
            print(f"Model '{model}' already loaded, reusing it.")
        return self.pipe(seed)
# One shared wrapper for the whole app, pre-loaded with the 'wikiart-1024'
# checkpoint (the same name the model picker starts on).
wrapper = InferenceWrapper('wikiart-1024')


def fn(seed, model):
    """Gradio callback: generate output for *seed* with the pretrained *model* name."""
    output = wrapper(seed, model)
    return output
# Build and serve the demo: a seed slider and a pretrained-model picker feed
# `fn`, whose result is shown as an image.
# NOTE(review): `gr.inputs.*` components and `enable_queue=` belong to the
# legacy Gradio API (removed in Gradio 4) — confirm the pinned gradio version
# before upgrading.
# NOTE(review): each example's model name must be one of the Radio choices;
# confirm 'stylegan3-r-ffhqu-256x256.pkl' is actually listed in model_map.
gr.Interface(
    fn,
    inputs=[
        gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),
        gr.inputs.Radio(list(model_map), type="value", default='wikiart-1024', label='Pretrained Model'),
    ],
    outputs='image',
    examples=[[343, 'wikiart-1024'], [456, 'landscapes-256'], [1234, 'stylegan3-r-ffhqu-256x256.pkl']],
    enable_queue=True,
).launch()