nightfury commited on
Commit
68c1ba8
1 Parent(s): 47240fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -4,13 +4,13 @@ from PIL import Image
4
  import torch
5
  import gradio as gr
6
 
7
-
8
 
9
  model2 = torch.hub.load(
10
  "AK391/animegan2-pytorch:main",
11
  "generator",
12
  pretrained=True,
13
- device="cuda",
14
  progress=False
15
  )
16
 
@@ -18,7 +18,7 @@ model2 = torch.hub.load(
18
  model1 = torch.hub.load("AK391/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v1", device="cuda")
19
  face2paint = torch.hub.load(
20
  'AK391/animegan2-pytorch:main', 'face2paint',
21
- size=512, device="cuda",side_by_side=False
22
  )
23
  def inference(img, ver):
24
  if ver == 'version 2 (🔺 robustness,🔻 stylization)':
 
4
  import torch
5
  import gradio as gr
6
 
7
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
8
 
9
  model2 = torch.hub.load(
10
  "AK391/animegan2-pytorch:main",
11
  "generator",
12
  pretrained=True,
13
+ device=DEVICE, #"cuda",
14
  progress=False
15
  )
16
 
 
18
  model1 = torch.hub.load("AK391/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v1", device="cuda")
19
  face2paint = torch.hub.load(
20
  'AK391/animegan2-pytorch:main', 'face2paint',
21
+ size=512, device=DEVICE,side_by_side=False
22
  )
23
  def inference(img, ver):
24
  if ver == 'version 2 (🔺 robustness,🔻 stylization)':