nateraw commited on
Commit
d0f5c68
1 Parent(s): 7448c94

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ model_map = torch.hub.load('nateraw/image-generation:main', 'model_map')
5
+
6
+ class InferenceWrapper:
7
+ def __init__(self, model):
8
+ self.model = model
9
+ self.pipe = torch.hub.load('nateraw/image-generation:main', 'styleganv3', pretrained=self.model)
10
+ def __call__(self, seed, model):
11
+ if model != self.model:
12
+ print(f"Loading model: {model}")
13
+ self.model = model
14
+ self.pipe = torch.hub.load('nateraw/image-generation:main', 'styleganv3', pretrained=self.model)
15
+ else:
16
+ print(f"Model '{model}' already loaded, reusing it.")
17
+ return self.pipe(seed)
18
+
19
+ wrapper = InferenceWrapper('wikiart-1024')
20
+ def fn(seed, model):
21
+ return wrapper(seed, model)
22
+
23
+ gr.Interface(
24
+ fn,
25
+ inputs=[
26
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),
27
+ gr.inputs.Radio(list(model_map), type="value", default='wikiart-1024', label='Pretrained Model')
28
+ ],
29
+ outputs='image',
30
+ examples=[[343, 'wikiart-1024'], [456, 'landscapes-256'], [1234, 'stylegan3-r-ffhqu-256x256.pkl']],
31
+ enable_queue=True
32
+ ).launch()