Spaces:
Build error
Build error
File size: 4,436 Bytes
66a6dc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
def spsa_func(input, params, process, i, sample_rate=24000):
    """Run ``process`` on CPU copies of the tensors and cast the result back.

    Args:
        input (Tensor): Signal tensor; moved to the CPU before the call.
        params (Tensor): Parameter tensor; moved to the CPU before the call.
        process (callable): Called as ``process(input, params, i, sample_rate)``
            and expected to return a tensor.
        i: Passed through to ``process`` unchanged.
        sample_rate (int): Passed through to ``process`` unchanged.

    Returns:
        Tensor: The value returned by ``process``, cast with ``type_as(input)``.
    """
    cpu_input = input.cpu()
    cpu_params = params.cpu()
    processed = process(cpu_input, cpu_params, i, sample_rate)
    return processed.type_as(input)
class SPSAFunction(torch.autograd.Function):
    """Autograd wrapper around an external, per-item processor.

    Both passes hand each batch item to ``thread_context`` as flattened
    NumPy arrays, either through the per-item channels in
    ``thread_context.procs`` (when ``thread_context.parallel`` is truthy) or
    by calling ``thread_context.static_forward`` / ``static_backward``
    directly in this process.
    """

    @staticmethod
    def _flat_numpy(tensor):
        """Return ``tensor`` detached, flattened to 1-D, as a CPU NumPy array.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        tensors (for example transposed inputs or upstream gradients) are
        copied instead of raising ``RuntimeError``.
        """
        return tensor.detach().reshape(-1).cpu().numpy()

    @staticmethod
    def forward(
        ctx,
        input,
        params,
        process,
        epsilon,
        thread_context,
        sample_rate=24000,
    ):
        """Apply processor to a batch of tensors using given parameters.

        Args:
            input (Tensor): Audio with shape: batch x 2 x samples
            params (Tensor): Processor parameters with shape: batch x params
            process (function): Function that will apply processing. Only
                stored on ``ctx`` here; it is not called by this method.
            epsilon (float): Perturbation strength for SPSA computation.
                Stored on ``ctx`` and forwarded during ``backward``.
            thread_context: Object providing ``parallel``, ``procs``,
                ``dsp``, ``static_forward`` and ``static_backward``.
                ``procs[i][1]`` must support ``send``/``recv``
                (presumably a multiprocessing pipe end -- confirm against
                the code that builds ``thread_context``).
            sample_rate (int): Sample rate forwarded to the processor.

        Returns:
            output (Tensor): Processed audio with same shape as input.
        """
        ctx.save_for_backward(input, params)
        ctx.epsilon = epsilon
        ctx.process = process
        ctx.thread_context = thread_context

        flat = SPSAFunction._flat_numpy
        batch_size = input.shape[0]
        z = torch.empty_like(input)

        if thread_context.parallel:
            # Dispatch every item first, then collect, so the receivers can
            # work concurrently. Item i always uses channel i.
            for i in range(batch_size):
                payload = (i, flat(input[i]), flat(params[i]), sample_rate)
                thread_context.procs[i][1].send(("forward", payload))
            for i in range(batch_size):
                z[i] = torch.from_numpy(thread_context.procs[i][1].recv())
        else:
            for i in range(batch_size):
                payload = (i, flat(input[i]), flat(params[i]), sample_rate)
                z[i] = torch.from_numpy(
                    thread_context.static_forward(thread_context.dsp, payload)
                )

        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Estimate gradients using SPSA.

        The estimate itself is produced by ``thread_context`` (worker
        channel or ``static_backward``); this method only packs the
        arguments and scatters the per-item results into batch tensors.

        Args:
            grad_output (Tensor): Gradient with respect to the output of
                ``forward``; indexed along its first dimension per item.

        Returns:
            tuple: ``(grads_input, grads_params, None, None, None, None)``,
            one entry per ``forward`` argument. The first two are ``None``
            when the corresponding input does not need a gradient.
        """
        input, params = ctx.saved_tensors
        epsilon = ctx.epsilon
        needs_input_grad = ctx.needs_input_grad[0]
        needs_param_grad = ctx.needs_input_grad[1]
        thread_context = ctx.thread_context

        flat = SPSAFunction._flat_numpy
        batch_size = input.shape[0]

        # Allocate result buffers only for the gradients autograd asked for.
        grads_input = torch.empty_like(input) if needs_input_grad else None
        grads_params = torch.empty_like(params) if needs_param_grad else None

        def make_payload(i):
            return (
                i,
                flat(input[i]),
                flat(params[i]),
                needs_input_grad,
                needs_param_grad,
                flat(grad_output[i]),
                epsilon,
            )

        def store(i, result):
            item_grad_input, item_grad_params = result
            if item_grad_input is not None:
                grads_input[i] = torch.from_numpy(item_grad_input)
            if item_grad_params is not None:
                grads_params[i] = torch.from_numpy(item_grad_params)

        if thread_context.parallel:
            for i in range(batch_size):
                thread_context.procs[i][1].send(("backward", make_payload(i)))
            # Wait for output
            for i in range(batch_size):
                store(i, thread_context.procs[i][1].recv())
        else:
            for i in range(batch_size):
                store(
                    i,
                    thread_context.static_backward(
                        thread_context.dsp, make_payload(i)
                    ),
                )

        return grads_input, grads_params, None, None, None, None
|