Mehdi Cherti committed
Commit 1e5aadc
Parent: c7f1d48

use native for app

Files changed (3):
  1. app.py +2 -0
  2. score_sde/op/fused_act.py +12 -9
  3. score_sde/op/upfirdn2d.py +13 -9
app.py CHANGED
@@ -1,3 +1,5 @@
+import os
+os.environ["USE_NATIVE"] = "1"
 import math
 import torch
 import torchvision
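
These two added lines are the point of the commit: app.py opts into the pure-PyTorch ("native") code paths so the demo never has to JIT-compile the CUDA extensions. Note the ordering matters: both fused_act.py and upfirdn2d.py evaluate int(os.getenv("USE_NATIVE", "0")) once at import time, so the variable must be set before anything from score_sde.op is imported. A minimal sketch of that requirement (assuming score_sde/op/__init__.py re-exports these symbols, as in rosinality-style StyleGAN2 ports):

import os
os.environ["USE_NATIVE"] = "1"  # must run before the score_sde imports below

# fused_act.py and upfirdn2d.py read USE_NATIVE at module load time,
# so setting the variable after these imports would have no effect.
from score_sde.op import fused_leaky_relu, upfirdn2d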
score_sde/op/fused_act.py CHANGED
@@ -14,15 +14,18 @@ from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
+use_native = int(os.getenv("USE_NATIVE", "0"))
 
-module_path = os.path.dirname(__file__)
-fused = load(
-    "fused",
-    sources=[
-        os.path.join(module_path, "fused_bias_act.cpp"),
-        os.path.join(module_path, "fused_bias_act_kernel.cu"),
-    ],
-)
+
+if not use_native:
+    module_path = os.path.dirname(__file__)
+    fused = load(
+        "fused",
+        sources=[
+            os.path.join(module_path, "fused_bias_act.cpp"),
+            os.path.join(module_path, "fused_bias_act_kernel.cu"),
+        ],
+    )
 
 
 class FusedLeakyReLUFunctionBackward(Function):
@@ -92,7 +95,7 @@ class FusedLeakyReLU(nn.Module):
 
 
 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    if input.device.type == "cpu":
+    if input.device.type == "cpu" or use_native:
         rest_dim = [1] * (input.ndim - bias.ndim - 1)
         return (
             F.leaky_relu(
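
The diff truncates the body of the native branch. For reference, a self-contained sketch of what that branch computes in rosinality-style ports (the standalone function name here is illustrative; the real code returns the same expression inline): the per-channel bias is broadcast over all trailing dimensions, leaky ReLU is applied, and the result is rescaled by `scale` to match the CUDA kernel's output.

import torch
from torch.nn import functional as F

def fused_leaky_relu_native(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Reshape bias from (C,) to (1, C, 1, ..., 1) so it broadcasts over
    # the batch and all spatial dimensions of `input`.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    return (
        F.leaky_relu(
            input + bias.view(1, bias.shape[0], *rest_dim),
            negative_slope=negative_slope,
        )
        * scale
    )

# Example: an (N, C, H, W) activation with a per-channel bias.
x = torch.randn(2, 8, 16, 16)
b = torch.zeros(8)
y = fused_leaky_relu_native(x, b)  # same shape as x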
score_sde/op/upfirdn2d.py CHANGED
@@ -14,14 +14,18 @@ from torch.autograd import Function
 from torch.utils.cpp_extension import load
 from collections import abc
 
+use_native = int(os.getenv("USE_NATIVE", "0"))
+
+
 module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    "upfirdn2d",
-    sources=[
-        os.path.join(module_path, "upfirdn2d.cpp"),
-        os.path.join(module_path, "upfirdn2d_kernel.cu"),
-    ],
-)
+if not use_native:
+    upfirdn2d_op = load(
+        "upfirdn2d",
+        sources=[
+            os.path.join(module_path, "upfirdn2d.cpp"),
+            os.path.join(module_path, "upfirdn2d_kernel.cu"),
+        ],
+    )
 
 
 class UpFirDn2dBackward(Function):
@@ -151,7 +155,7 @@ class UpFirDn2d(Function):
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    if input.device.type == "cpu":
+    if input.device.type == "cpu" or use_native:
         out = upfirdn2d_native(
             input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
         )
@@ -173,7 +177,7 @@ def upfirdn2d_ada(input, kernel, up=1, down=1, pad=(0, 0)):
     if len(pad) == 2:
         pad = (pad[0], pad[1], pad[0], pad[1])
 
-    if input.device.type == "cpu":
+    if input.device.type == "cpu" or use_native:
         out = upfirdn2d_native(input, kernel, *up, *down, *pad)
 
     else:
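
With the flag set, even CUDA tensors take the `upfirdn2d_native` branch, so `torch.utils.cpp_extension.load` is never called and no compiler toolchain is needed at runtime. A hedged usage sketch of the public `upfirdn2d` entry point (the blur kernel and padding values are illustrative, following the usual StyleGAN2 upsampling convention):

import os
os.environ["USE_NATIVE"] = "1"  # as in app.py, set before the import

import torch
from score_sde.op.upfirdn2d import upfirdn2d

# Separable 4-tap binomial blur kernel (values illustrative).
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 64, 64)
# 2x upsampling: for a 4-tap kernel the customary padding is (2, 1).
y = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))
print(y.shape)  # torch.Size([1, 3, 128, 128])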