add some code
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from . import quantization
|
||||
from . import sparsification
|
||||
@@ -0,0 +1 @@
|
||||
from .softquant import soft_quant, remove_soft_quant
|
||||
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Compute a per-row quantization scale for a 2d weight matrix.

    The scale is the per-row maximum of two lower bounds: one derived from the
    single-coefficient range (|w| <= 127 * scale) and one derived from the sum
    of adjacent coefficient pairs (|w_even + w_odd| <= 129 * scale).

    Args:
        weight (torch.Tensor): 2d tensor of shape (n_out, n_in); n_in must be
            a multiple of 4.

    Returns:
        torch.Tensor: 1d tensor of length n_out with one scale per row.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # pad with zero rows so the row count is a multiple of 8
        # (fix: the original computed ``pad = n_out - n_out % 8``, which does
        # not round up to a multiple of 8; padded rows are sliced off below,
        # so results for real rows are unchanged)
        pad = 8 - n_out % 8
        weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)

    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    # the scale must satisfy both constraints, so take the larger bound
    scale = torch.maximum(scale_max, scale_sum)

    # drop scales belonging to padding rows
    return scale[:n_out]
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draw uniform noise in [-0.5, 0.5) scaled by the per-row optimal quantization scale of `weight`.

    The weight tensor is first rearranged into the 2d layout used for
    quantization (depending on the module type), the noise is scaled row-wise
    by compute_optimal_scale, and finally mapped back to the original layout.

    Args:
        module: the layer owning `weight`; its type selects the layout handling
        weight (torch.Tensor): parameter tensor to generate noise for

    Returns:
        torch.Tensor: noise tensor with the same shape as `weight`

    Raises:
        ValueError: if the module/weight combination is not supported
    """
    if isinstance(module, torch.nn.Conv1d):
        # Conv1d weight is (out_channels, in_channels, kernel_size);
        # flatten to (out_channels, kernel_size * in_channels) rows
        w = weight.permute(0, 2, 1).flatten(1)
        noise = torch.rand_like(w) - 0.5
        noise[w == 0] = 0 # ignore zero entries from sparsification
        scale = compute_optimal_scale(w)
        noise = noise * scale.unsqueeze(-1)
        # restore the original (out_channels, in_channels, kernel_size) layout
        noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        # ConvTranspose1d weight is (in_channels, out_channels, kernel_size)
        i, o, k = weight.shape
        w = weight.permute(2, 1, 0).reshape(k * o, i)
        noise = torch.rand_like(w) - 0.5
        noise[w == 0] = 0 # ignore zero entries from sparsification
        scale = compute_optimal_scale(w)
        noise = noise * scale.unsqueeze(-1)
        noise = noise.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        # dense 2d weight (e.g. Linear / GRU slices): rows are already quantization rows
        noise = torch.rand_like(weight) - 0.5
        noise[weight == 0] = 0 # ignore zero entries from sparsification
        scale = compute_optimal_scale(weight)
        noise = noise * scale.unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')

    return noise
|
||||
|
||||
class SoftQuant:
    """Injects quantization noise into selected parameters during training forward passes.

    A forward pre-hook adds noise to each named parameter before the forward
    pass and a forward hook subtracts the exact same noise afterwards, so the
    stored weights remain (numerically) unchanged while the forward pass sees
    quantization-perturbed weights. Hooks are inactive in eval mode.
    """

    # attribute names of the parameters that receive quantization noise
    # (fix: was annotated ``name: str``, but the attribute is ``names``)
    names: list

    def __init__(self, names: str, scale: float) -> None:
        self.names = names
        # per-name noise tensors of the currently running forward pass
        self.quantization_noise = None
        # fixed noise amplitude; None selects scale-adaptive noise via q_scaled_noise
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        # noise is applied only while training
        if not module.training: return

        if before:
            # pre-hook: add noise to every registered parameter and remember it
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            # post-hook: subtract the stored noise to restore the weights
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=['weight'], scale=None):
        """Register soft-quantization hooks on `module`.

        Args:
            module: torch module whose parameters should be perturbed
            names: list of parameter attribute names to perturb
            scale: fixed noise amplitude, or None for scale-adaptive noise

        Returns:
            SoftQuant: the hook state object shared by both hooks

        Raises:
            ValueError: if `module` lacks one of the named attributes
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                # fix: previously raised ValueError("") with an empty message
                raise ValueError(f"module {module} has no attribute '{name}'")

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        # tag both hooks so remove_soft_quant can identify them later
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=['weight'], scale=None):
    """Attach soft-quantization noise hooks to `module` and return the module.

    Convenience wrapper around SoftQuant.apply that allows chaining on the
    module itself.
    """
    SoftQuant.apply(module, names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=['weight']):
    """Remove soft-quantization hooks previously installed via soft_quant.

    Only hooks whose SoftQuant state matches the given `names` list are
    removed; other hooks on the module are left untouched.

    Args:
        module: torch module to strip hooks from
        names: parameter-name list identifying the SoftQuant instance

    Returns:
        the (modified) module
    """
    # fix: iterate over a snapshot — deleting from the hook dicts while
    # iterating them directly raises "dictionary changed size during iteration"
    for k, hook in list(module._forward_pre_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_hooks[k]

    return module
|
||||
@@ -0,0 +1,2 @@
|
||||
from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
|
||||
from .meta_critic import MetaCritic
|
||||
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
class MetaCritic():
    """Assesses relevance of discriminator scores from real-vs-generated score gaps.

    Optionally normalizes the gap using exponentially averaged per-discriminator
    (or joint) statistics before thresholding it with a ReLU.
    """

    def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
        """ Class for assessing relevance of discriminator scores

        Args:
            normalize (bool, optional): if True, score gaps are normalized with running statistics. Defaults to False.
            gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
            beta (float, optional): Minimum confidence related threshold. Defaults to 0.0.
            joint_stats (bool, optional): if True, one set of statistics is shared by all discriminators. Defaults to False.
        """
        self.normalize = normalize
        self.gamma = gamma
        self.beta = beta
        self.joint_stats = joint_stats

        # running statistics keyed by disc_id (or by 0 when joint_stats is set)
        self.disc_stats = dict()

    def __call__(self, disc_id, real_scores, generated_scores):
        """ calculates relevance from normalized scores

        Args:
            disc_id (any valid key): id for tracking discriminator statistics
            real_scores (torch.tensor): scores for real data
            generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device

        Returns:
            torch.tensor: output-domain relevance
        """

        if self.normalize:
            # current-batch statistics of the score gap
            real_std = torch.std(real_scores.detach()).cpu().item()
            gen_std = torch.std(generated_scores.detach()).cpu().item()
            std = (real_std**2 + gen_std**2) ** .5
            mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()

            key = 0 if self.joint_stats else disc_id

            # exponential moving average; first call initializes the entry
            if key in self.disc_stats:
                self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
                self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
            else:
                self.disc_stats[key] = {
                    'std': std + 1e-5,  # small offset guards against division by zero
                    'mean': mean
                }

            std = self.disc_stats[key]['std']
            mean = self.disc_stats[key]['mean']
        else:
            mean, std = 0, 1

        # relevance is the (normalized) positive part of the score gap above beta
        relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)

        # fix: removed dead `if False: print(...)` debug statement

        return relevance
|
||||
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def view_one_hot(index, length):
    """Return a view/broadcast shape of `length` ones with -1 at position `index`.

    Used to reshape 1d per-dimension vectors so they broadcast along a single
    axis of an n-dimensional grid.
    """
    return [-1 if dim == index else 1 for dim in range(length)]
|
||||
|
||||
def create_smoothing_kernel(widths, gamma=1.5):
    """ creates a truncated gaussian smoothing kernel for the given widths

    Parameters:
    -----------
    widths: list[Int] or torch.LongTensor
        specifies the shape of the smoothing kernel, entries must be > 0.

    gamma: float, optional
        decay factor for gaussian relative to kernel size

    Returns:
    --------
    kernel: torch.FloatTensor
        kernel of shape `widths`, normalized to sum to one
    """

    widths = torch.LongTensor(widths)
    num_dims = len(widths)

    assert(widths.min() > 0)

    # kernel center and per-dimension gaussian width
    centers = widths.float() / 2 - 0.5
    sigmas = gamma * (centers + 1)

    # squared normalized distance from the center along each dimension
    # (fix: removed the dead `vals = []` assignment that was immediately overwritten)
    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
    # broadcast-sum the per-dimension terms into an n-dimensional grid
    vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])

    kernel = torch.exp(- vals)
    # normalize to unit mass
    kernel = kernel / kernel.sum()

    return kernel
|
||||
|
||||
|
||||
def create_partition_kernel(widths, strides):
    """ creates a partition kernel for mapping a convolutional network output back to the input domain

    Given a fully convolutional network with receptive field of shape widths and the given strides, this
    function constructs an interpolation kernel whose translations by multiples of the given strides form
    a partition of one on the input domain.

    Parameter:
    ----------
    widths: list[Int] or torch.LongTensor
        shape of receptive field

    strides: list[Int] or torch.LongTensor
        total strides of convolutional network

    Returns:
    kernel: torch.FloatTensor
    """

    num_dims = len(widths)
    # only 1d, 2d and 3d convolutions are supported
    assert num_dims == len(strides) and num_dims in {1, 2, 3}

    convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}

    widths = torch.LongTensor(widths)
    strides = torch.LongTensor(strides)

    # box prototype covering one stride cell (clipped to the receptive field)
    proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())

    # create interpolation kernel eta
    eta_widths = widths - strides + 1
    if eta_widths.min() <= 0:
        # degenerate case: stride exceeds receptive field; fall back to width-1 eta
        print("[create_partition_kernel] warning: receptive field does not cover input domain")
        eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))

    eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())

    padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
    padded_proto_kernel = F.pad(proto_kernel, padding)
    padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
    # smooth the box prototype with eta; stride translations of the result sum to one
    kernel = convs[num_dims](padded_proto_kernel, eta)

    return kernel
|
||||
|
||||
|
||||
def receptive_field(conv_model, input_shape, output_position):
    """ estimates boundaries of receptive field connected to output_position via autograd

    Parameters:
    -----------
    conv_model: nn.Module or autograd function
        function or model implementing fully convolutional model

    input_shape: List[Int]
        input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]

    output_position: List[Int]
        output position for which the receptive field is determined; the function raises an exception
        if output_position is out of bounds for the given input_shape.

    Returns:
    --------
    low: List[Int]
        start indices of receptive field

    high: List[Int]
        stop indices of receptive field
    """

    # random probe with gradient tracking enabled
    probe = torch.randn((1,) + tuple(input_shape), requires_grad=True)
    output = conv_model(probe)

    # collapse channels and drop the batch dimension
    output = torch.sum(output, 1)[0]

    # build a one-hot mask selecting the requested output position
    mask = torch.zeros_like(output)
    position = [torch.tensor(i) for i in output_position]
    try:
        mask.index_put_(position, torch.tensor(1, dtype=mask.dtype))
    except IndexError:
        raise ValueError('output_position out of bounds')

    # backprop from the selected output element only
    (mask * output).sum().backward()

    # sum gradient over channels, drop batch dimension, and find its support
    grad = torch.sum(probe.grad, dim=1)[0]
    support = torch.nonzero(grad, as_tuple=True)

    low = [t.min().item() for t in support]
    high = [t.max().item() for t in support]
    return low, high
|
||||
|
||||
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
    """ attempts to estimate receptive field size, strides and left paddings for given model


    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    width: Int
        width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    max_stride: Int, optional
        assumed maximal stride of the model for any dimension, when set too low the function may fail for
        any value of width

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError, KeyError

    """

    # hyper-square probe input, hopefully large enough to contain two receptive fields
    input_shape = [num_channels] + num_dims * [width]
    # two neighbouring output positions; their receptive fields should be
    # identical in size and shifted against each other by exactly one stride
    output_position1 = num_dims * [width // (2 * max_stride)]
    output_position2 = num_dims * [width // (2 * max_stride) + 1]

    low1, high1 = receptive_field(model, input_shape, output_position1)
    low2, high2 = receptive_field(model, input_shape, output_position2)

    widths1 = [h - l + 1 for l, h in zip(low1, high1)]
    widths2 = [h - l + 1 for l, h in zip(low2, high2)]

    # sanity check: equal sizes and a non-zero shift in every dimension;
    # otherwise the probe width was too small to resolve the strides
    if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
        raise ValueError("[estimate_strides]: widths to small to determine strides")

    receptive_field_size = widths1
    # the shift between neighbouring receptive fields is the total stride
    strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
    # left padding aligns receptive-field starts with stride multiples
    left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]

    return receptive_field_size, strides, left_paddings
|
||||
|
||||
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
    """ determines size of receptive field, strides and padding probabilistically


    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    max_width: Int
        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    width_hint: Int or None, optional
        initial probe width to try (doubled internally); speeds up the search

    stride_hint: Int or None, optional
        initial assumed maximal stride to try; speeds up the search

    verbose: bool, optional
        if true, the function prints parameters for individual trials

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError

    """

    max_stride = max_width // 2
    stride = max_stride // 100
    width = max_width // 100

    # hints let us skip the smallest trial sizes
    if width_hint is not None: width = 2 * width_hint
    if stride_hint is not None: stride = stride_hint

    did_it = False
    while width < max_width and stride < max_stride:
        try:
            if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
            receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
            did_it = True
        except Exception:
            # fix: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; a failed trial (probe too small for the model)
            # simply means we retry with larger parameters below
            pass

        if did_it: break

        # grow the probe; once the width budget is exhausted, grow the stride
        width *= 2
        if width >= max_width and stride < max_stride:
            stride *= 2
            width = 2 * stride

    if not did_it:
        raise ValueError(f'could not determine conv parameter with given max_width={max_width}')

    return receptive_field_size, strides, left_paddings
|
||||
|
||||
|
||||
class GradWeight(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by a fixed
    weight tensor in the backward pass. The weight itself receives no gradient."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, x, weight):
        # stash the weight for use in backward
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # weight the gradient w.r.t. x; None for the weight argument
        return grad_output * weight, None
|
||||
|
||||
|
||||
# API
|
||||
|
||||
def relegance_gradient_weighting(x, weight):
    """Pass `x` through unchanged while weighting its gradient in the backward pass.

    Args:
        x (torch.tensor): input tensor
        weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass

    Returns:
        torch.tensor: the unmodified input tensor x
    """
    return x if weight is None else GradWeight.apply(x, weight)
|
||||
|
||||
|
||||
|
||||
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
    """ creates parameters for mapping back output domain relevance to input domain

    Args:
        model (nn.Module or autograd.Function): fully convolutional model
        num_channels (int): number of input channels to model
        num_dims (int): number of input dimensions of model (without channel and batch dimension)
        width_hint(int or None): optional hint at maximal width of receptive field
        stride_hint(int or None): optional hint at maximal stride
        verbose (bool): if True, inspection prints trial parameters

    Returns:
        dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution

    Raises:
        RuntimeError: if estimation of parameters fails due to exceeded compute budget
    """

    # shrink the width budget as the dimensionality grows
    max_width = int(100000 / (10 ** num_dims))

    try:
        receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except Exception:
        # fix: was a bare `except:`; retry once with a 10x larger budget and
        # translate a second failure into a RuntimeError for the caller
        max_width *= 10
        try:
            receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
        except Exception:
            raise RuntimeError("could not determine parameters within given compute budget")

    partition_kernel = create_partition_kernel(receptive_field_size, strides)
    # one kernel channel per model input channel
    partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)

    tconv_parameters = {
        'kernel': partition_kernel,
        'receptive_field_shape': receptive_field_size,
        'stride': strides,
        'left_padding': left_paddings,
        'num_dims': num_dims
    }

    return tconv_parameters
|
||||
|
||||
|
||||
|
||||
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
    """ maps output-domain relevance to input-domain relevance via transpose convolution

    Args:
        od_relevance (torch.tensor): output-domain relevance
        tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel

    Returns:
        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
                      Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
                      convenient way to adjust the output to the correct size.

    Raises:
        ValueError: if number of dimensions is not supported
    """

    kernel = tconv_parameters['kernel'].to(od_relevance.device)
    rf_shape = tconv_parameters['receptive_field_shape']
    stride = tconv_parameters['stride']
    left_padding = tconv_parameters['left_padding']

    num_dims = len(kernel.shape) - 2

    # repeat boundary values so edge positions receive full kernel mass
    od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
    padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
    od_padding = od_padding[::2]

    # apply mapping and left trimming
    if num_dims == 1:
        id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
    elif num_dims == 2:
        id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
    elif num_dims == 3:
        # fix: the 3d branch previously called F.conv_transpose2d, which fails
        # (or silently misbehaves) on 5d inputs
        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
    else:
        raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')

    return id_relevance
|
||||
|
||||
|
||||
def relegance_resize_relevance_to_input_size(reference_input, relevance):
    """ adjusts size of relevance tensor to reference input size

    Args:
        reference_input (torch.tensor): discriminator input tensor for reference
        relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input

    Returns:
        torch.tensor: resized relevance

    Raises:
        ValueError: if number of dimensions is not supported
    """
    resized_relevance = torch.zeros_like(reference_input)

    # spatial dimensions, excluding batch and channel
    num_dims = len(reference_input.shape) - 2
    if num_dims not in {1, 2, 3}:
        raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')

    with torch.no_grad():
        # trim each trailing spatial dimension to the common size
        trim = tuple(
            slice(0, min(reference_input.size(d), relevance.size(d)))
            for d in range(-num_dims, 0)
        )
        resized_relevance[:] = relevance[(Ellipsis,) + trim]

    return resized_relevance
|
||||
@@ -0,0 +1,6 @@
|
||||
from .gru_sparsifier import GRUSparsifier
|
||||
from .conv1d_sparsifier import Conv1dSparsifier
|
||||
from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
|
||||
from .linear_sparsifier import LinearSparsifier
|
||||
from .common import sparsify_matrix, calculate_gru_flops_per_step
|
||||
from .utils import mark_for_sparsification, create_sparsifier
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
class BaseSparsifier:
    """Schedules gradual sparsification over training steps.

    Subclasses must implement ``sparsify(alpha, verbose)``; ``alpha`` decays
    polynomially from 1 towards 0 between ``start`` and ``stop`` and stays 0
    afterwards.
    """

    def __init__(self, task_list, start, stop, interval, exponent=3):
        # schedule parameters
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # number of calls to step() so far
        self.step_counter = 0

    def step(self, verbose=False):
        """Advance the schedule by one step; call sparsify when due."""
        self.step_counter += 1
        counter = self.step_counter

        # before the schedule starts: nothing to do
        if counter < self.start:
            return

        if counter >= self.stop:
            # past the end of the schedule: hold full target sparsity
            alpha = 0
        else:
            # inside the schedule: only act every `interval`-th step
            if counter % self.interval:
                return
            # interpolation factor decays polynomially from 1 to 0
            alpha = ((self.stop - counter) / (self.stop - self.start)) ** self.exponent

        self.sparsify(alpha, verbose=verbose)
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
debug=True
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    Keeps the highest-energy blocks so that roughly `density` of all blocks
    survive; all other blocks are zeroed.

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : int
        target density
    block_size : [int, int]
        block size dimensions
    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
    return_mask : bool
        If true, the element-wise 0/1 mask is returned alongside the matrix
    """

    rows, cols = matrix.shape
    block_rows, block_cols = block_size

    if rows % block_rows or cols % block_cols:
        raise ValueError(f"block size {(block_rows, block_cols)} does not divide matrix size {(rows, cols)}")

    # set the diagonal aside so it survives masking
    if keep_diagonal:
        if rows != cols:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # energy of each (block_rows x block_cols) sub-block
    blocks = torch.reshape(matrix, (rows // block_rows, block_rows, cols // block_cols, block_cols))
    block_energies = torch.sum(torch.sum(blocks ** 2, dim=3), dim=1)

    num_blocks = (rows * cols) // (block_rows * block_cols)
    num_survivors = round(num_blocks * density)

    # energy threshold below which blocks are zeroed
    if num_survivors == 0:
        threshold = 0
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-num_survivors]

    # block-level mask, expanded to element resolution
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(torch.repeat_interleave(mask, block_rows, dim=0), block_cols, dim=1)

    # apply mask and restore the spared diagonal
    masked_matrix = mask * matrix + to_spare

    return (masked_matrix, mask) if return_mask else masked_matrix
|
||||
|
||||
def calculate_gru_flops_per_step(gru, sparsification_dict=None, drop_input=False):
    """ estimates multiply-add flops for one step of a (sparsified) GRU

    Parameters:
    -----------
    gru : torch.nn.GRU
        GRU (only input_size and hidden_size are read)
    sparsification_dict : dict or None
        optional dict mapping weight names ('W_ir', 'W_iz', 'W_in', 'W_hr',
        'W_hz', 'W_hn') to tuples whose first entry is the target density;
        missing keys count as fully dense
    drop_input : bool
        if True, the input matrix-vector products are excluded

    Returns:
    --------
    estimated flops per step (float)
    """
    # avoid a mutable default argument; None stands in for "fully dense"
    if sparsification_dict is None:
        sparsification_dict = dict()

    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    # average density over the three input / recurrent gate matrices
    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class Conv1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Conv1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv1d, params), where conv1d is an instance
            of torch.nn.Conv1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to Conv1dSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.Conv1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # forward the caller-supplied exponent (was hard-coded to 3,
        # silently ignoring the constructor argument)
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None


    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for conv, params in self.task_list:
                # weight_v is present when weight normalization is applied;
                # sparsify the underlying direction tensor in that case
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # reshape (out, in, kernel) weight into a 2d matrix for block sparsification
                i, o, k = weight.shape
                w = weight.permute(0, 2, 1).flatten(1)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(i, k, o).permute(0, 2, 1)
                weight[:] = w

                # a previously-zeroed block becoming non-zero indicates a resurrection
                if self.last_mask is not None:
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch

    # small smoke test: sparsify a Conv1d for 100 steps and inspect the weight
    test_conv = torch.nn.Conv1d(8, 16, 8)
    sparsification_params = (0.2, [8, 4])

    sparsifier = Conv1dSparsifier([(test_conv, sparsification_params)], 0, 100, 5)

    for step in range(100):
        sparsifier.step(verbose=True)

    print(test_conv.weight)
|
||||
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class ConvTranspose1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv1d, params), where conv1d is an instance
            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to ConvTranspose1dSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """

        # forward the caller-supplied exponent (was hard-coded to 3,
        # silently ignoring the constructor argument)
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for conv, params in self.task_list:
                # weight_v is present when weight normalization is applied;
                # sparsify the underlying direction tensor in that case
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # reshape (in, out, kernel) weight into a 2d matrix for block sparsification
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                # a previously-zeroed block becoming non-zero indicates a resurrection
                if self.last_mask is not None:
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch

    # small smoke test: sparsify a ConvTranspose1d for 100 steps and inspect the weight
    test_conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    sparsification_params = (0.2, [8, 4])

    sparsifier = ConvTranspose1dSparsifier([(test_conv, sparsification_params)], 0, 100, 5)

    for step in range(100):
        sparsifier.step(verbose=True)

    print(test_conv.weight)
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class GRUSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
            should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...     'W_ir' : (0.5, [2, 2], False),
        ...     'W_iz' : (0.6, [2, 2], False),
        ...     'W_in' : (0.7, [2, 2], False),
        ...     'W_hr' : (0.1, [4, 4], True),
        ...     'W_hz' : (0.2, [4, 4], True),
        ...     'W_hn' : (0.3, [4, 4], True),
        ... }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # forward the caller-supplied exponent (was hard-coded to 3,
        # silently ignoring the constructor argument)
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: rows [i*h, (i+1)*h) of weight_ih_l0 hold
                # the reset, update, and new-gate matrices, in this order
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        # weight_ih_l0_v is present when weight normalization is applied
                        if hasattr(gru, 'weight_ih_l0_v'):
                            weight = gru.weight_ih_l0_v
                        else:
                            weight = gru.weight_ih_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, : ],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        # a previously-zeroed block becoming non-zero indicates a resurrection
                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_ih_l0_v")

                        self.last_masks[key] = new_mask

                # recurrent weights: same row layout in weight_hh_l0
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        if hasattr(gru, 'weight_hh_l0_v'):
                            weight = gru.weight_hh_l0_v
                        else:
                            weight = gru.weight_hh_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, : ],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        # gate the print on the debug flag (was `and True`,
                        # which printed regardless of the debug setting and
                        # was inconsistent with the input-weight branch)
                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_hh_l0_v")

                        self.last_masks[key] = new_mask
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("Testing sparsifier")

    # smoke test: sparsify every gate of a small GRU for 100 steps
    test_gru = torch.nn.GRU(10, 20)
    gate_params = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(test_gru, gate_params)], 0, 100, 10)

    for step in range(100):
        sparsifier.step(verbose=True)

    print(test_gru.weight_hh_l0)
|
||||
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (linear, params), where linear is an instance
            of torch.nn.Linear and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to LinearSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> linear = torch.nn.Linear(8, 16)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """

        # forward the caller-supplied exponent (was hard-coded to 3,
        # silently ignoring the constructor argument)
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for linear, params in self.task_list:
                # weight_v is present when weight normalization is applied;
                # sparsify the underlying direction tensor in that case
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                # a previously-zeroed block becoming non-zero indicates a resurrection
                if self.last_mask is not None:
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch

    # small smoke test: sparsify a Linear layer for 100 steps and inspect the weight
    test_linear = torch.nn.Linear(8, 16)
    sparsification_params = (0.2, [4, 2])

    sparsifier = LinearSparsifier([(test_linear, sparsification_params)], 0, 100, 5)

    for step in range(100):
        sparsifier.step(verbose=True)

    print(test_linear.weight)
|
||||
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
|
||||
from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
|
||||
|
||||
def mark_for_sparsification(module, params):
    """ tags a module with sparsification parameters for later pickup by create_sparsifier """
    module.sparsify = True
    module.sparsification_params = params
    return module
|
||||
|
||||
def create_sparsifier(module, start, stop, interval):
    """ builds a step function that sparsifies all marked submodules of module

    Walks module.modules(), creates a matching sparsifier for every submodule
    previously tagged by mark_for_sparsification, and returns a callable
    sparsify(verbose=False) that advances all of them by one step.
    """
    sparsifier_list = []

    # isinstance checks dispatch each marked submodule to its sparsifier class
    dispatch = [
        (torch.nn.GRU, GRUSparsifier),
        (torch.nn.Linear, LinearSparsifier),
        (torch.nn.Conv1d, Conv1dSparsifier),
        (torch.nn.ConvTranspose1d, ConvTranspose1dSparsifier),
    ]

    for submodule in module.modules():
        if not hasattr(submodule, 'sparsify'):
            continue
        for module_type, sparsifier_type in dispatch:
            if isinstance(submodule, module_type):
                sparsifier_list.append(
                    sparsifier_type([(submodule, submodule.sparsification_params)], start, stop, interval)
                )
                break
        else:
            print(f"[create_sparsifier] warning: module {submodule} marked for sparsification but no suitable sparsifier exists.")

    def sparsify(verbose=False):
        for sparsifier in sparsifier_list:
            sparsifier.step(verbose)

    return sparsify
|
||||
|
||||
|
||||
def count_parameters(model, verbose=False):
    """ returns the total number of parameters in model

    Parameters:
    -----------
    model : torch.nn.Module
        model whose parameters are counted
    verbose : bool
        if true, the per-parameter counts are printed out
    """
    total = 0
    for name, p in model.named_parameters():
        # numel() gives the element count directly, without allocating a
        # throwaway ones tensor per parameter as ones_like(...).sum() did
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total
|
||||
|
||||
def estimate_nonzero_parameters(module):
    """ estimates the number of weights that sparsification will zero out

    Returns 0 for modules that are not marked for sparsification. For marked
    modules the count is derived from the target densities stored in
    module.sparsification_params.

    Raises:
    -------
    ValueError if the module is marked but no estimate is implemented for its type.
    """
    num_zero_parameters = 0
    if hasattr(module, 'sparsify'):
        params = module.sparsification_params
        if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
            num_zero_parameters = module.weight.numel() * (1 - params[0])
        elif isinstance(module, torch.nn.GRU):
            num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
            num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
        elif isinstance(module, torch.nn.Linear):
            # count zeroed weights (1 - density); the previous code used
            # params[0] directly, which counted the surviving weights instead
            num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
        else:
            raise ValueError(f'unknown sparsification method for module of type {type(module)}')

    # the original function was missing this return statement
    return num_zero_parameters
|
||||
@@ -0,0 +1 @@
|
||||
torch
|
||||
48
managed_components/78__esp-opus/dnn/torch/dnntools/setup.py
Normal file
48
managed_components/78__esp-opus/dnn/torch/dnntools/setup.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
#!/usr/bin/env/python
|
||||
import os
|
||||
from setuptools import setup
|
||||
|
||||
lib_folder = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f:
|
||||
install_requires = list(f.read().splitlines())
|
||||
|
||||
print(install_requires)
|
||||
|
||||
setup(name='dnntools',
|
||||
version='1.0',
|
||||
author='Jan Buethe',
|
||||
author_email='jbuethe@amazon.de',
|
||||
description='Non-Standard tools for deep neural network training with PyTorch',
|
||||
packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
|
||||
install_requires=install_requires
|
||||
)
|
||||
54
managed_components/78__esp-opus/dnn/torch/fargan/README.md
Normal file
54
managed_components/78__esp-opus/dnn/torch/fargan/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Framewise Auto-Regressive GAN (FARGAN)
|
||||
|
||||
Implementation of FARGAN, a low-complexity neural vocoder. Pre-trained models
|
||||
are provided as C code in the dnn/ directory with the corresponding model in
|
||||
dnn/models/ directory (name starts with fargan_). If you don't want to train
|
||||
a new FARGAN model, you can skip straight to the Inference section.
|
||||
|
||||
## Data preparation
|
||||
|
||||
For data preparation you need to build Opus as detailed in the top-level README.
|
||||
You will need to use the --enable-deep-plc configure option.
|
||||
The build will produce an executable named "dump_data".
|
||||
To prepare the training data, run:
|
||||
```
|
||||
./dump_data -train in_speech.pcm out_features.f32 out_speech.pcm
|
||||
```
|
||||
Where the in_speech.pcm speech file is a raw 16-bit PCM file sampled at 16 kHz.
|
||||
The speech data used for training the model can be found at:
|
||||
https://media.xiph.org/lpcnet/speech/tts_speech_negative_16k.sw
|
||||
|
||||
## Training
|
||||
|
||||
To perform pre-training, run the following command:
|
||||
```
|
||||
python ./train_fargan.py out_features.f32 out_speech.pcm output_dir --epochs 400 --batch-size 4096 --lr 0.002 --cuda-visible-devices 0
|
||||
```
|
||||
Once pre-training is complete, run adversarial training using:
|
||||
```
|
||||
python adv_train_fargan.py out_features.f32 out_speech.pcm output_dir --lr 0.000002 --reg-weight 5 --batch-size 160 --cuda-visible-devices 0 --initial-checkpoint output_dir/checkpoints/fargan_400.pth
|
||||
```
|
||||
The final model will be in output_dir/checkpoints/fargan_adv_50.pth.
|
||||
|
||||
The model can optionally be converted to C using:
|
||||
```
|
||||
python dump_fargan_weights.py output_dir/checkpoints/fargan_adv_50.pth fargan_c_dir
|
||||
```
|
||||
which will create a fargan_data.c and a fargan_data.h file in the fargan_c_dir directory.
|
||||
Copy these files to the opus/dnn/ directory (replacing the existing ones) and recompile Opus.
|
||||
|
||||
## Inference
|
||||
|
||||
To run the inference, start by generating the features from the audio using:
|
||||
```
|
||||
./fargan_demo -features test_speech.pcm test_features.f32
|
||||
```
|
||||
Synthesis can be achieved either using the PyTorch code or the C code.
|
||||
To synthesize from PyTorch, run:
|
||||
```
|
||||
python test_fargan.py output_dir/checkpoints/fargan_adv_50.pth test_features.f32 output_speech.pcm
|
||||
```
|
||||
To synthesize from the C code, run:
|
||||
```
|
||||
./fargan_demo -fargan-synthesis test_features.f32 output_speech.pcm
|
||||
```
|
||||
@@ -0,0 +1,278 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
import sys
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
import fargan
|
||||
from dataset import FARGANDataset
|
||||
from stft_loss import *
|
||||
|
||||
source_dir = os.path.split(os.path.abspath(__file__))[0]
|
||||
sys.path.append(os.path.join(source_dir, "../osce/"))
|
||||
|
||||
import models as osce_models
|
||||
|
||||
|
||||
def fmap_loss(scores_real, scores_gen):
    """Feature-matching loss between real and generated discriminator activations.

    scores_real / scores_gen: one list per sub-discriminator; each inner list
    holds the intermediate feature maps, with the final element being the
    discriminator score itself (excluded from the match). Real-side features
    are detached so no gradient flows back into the discriminator.
    """
    num_discs = len(scores_real)
    total = 0
    for feats_real, feats_gen in zip(scores_real, scores_gen):
        num_layers = len(feats_gen) - 1  # last entry is the score, not a feature map
        weight = 4 / num_discs / num_layers
        for real, gen in zip(feats_real[:num_layers], feats_gen[:num_layers]):
            total = total + weight * F.l1_loss(gen, real.detach())
    return total
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
|
||||
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
|
||||
parser.add_argument('output', type=str, help='path to output folder')
|
||||
|
||||
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
|
||||
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
|
||||
|
||||
|
||||
model_group = parser.add_argument_group(title="model parameters")
|
||||
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
|
||||
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
|
||||
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
|
||||
|
||||
training_group = parser.add_argument_group(title="training parameters")
|
||||
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
|
||||
training_group.add_argument('--lr', type=float, help='learning rate, default: 5e-4', default=5e-4)
|
||||
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 50', default=50)
|
||||
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 60', default=60)
|
||||
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 0.0', default=0.0)
|
||||
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
|
||||
training_group.add_argument('--reg-weight', type=float, help='regression loss weight, default: 1.0', default=1.0)
|
||||
training_group.add_argument('--fmap-weight', type=float, help='feature matchin loss weight, default: 1.0', default=1.)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cuda_visible_devices != None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
|
||||
|
||||
# checkpoints
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
checkpoint = dict()
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
|
||||
# training parameters
|
||||
batch_size = args.batch_size
|
||||
lr = args.lr
|
||||
epochs = args.epochs
|
||||
sequence_length = args.sequence_length
|
||||
lr_decay = args.lr_decay
|
||||
|
||||
adam_betas = [0.8, 0.99]
|
||||
adam_eps = 1e-8
|
||||
features_file = args.features
|
||||
signal_file = args.signal
|
||||
|
||||
# model parameters
|
||||
cond_size = args.cond_size
|
||||
|
||||
|
||||
checkpoint['batch_size'] = batch_size
|
||||
checkpoint['lr'] = lr
|
||||
checkpoint['lr_decay'] = lr_decay
|
||||
checkpoint['epochs'] = epochs
|
||||
checkpoint['sequence_length'] = sequence_length
|
||||
checkpoint['adam_betas'] = adam_betas
|
||||
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
checkpoint['model_args'] = ()
|
||||
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
|
||||
print(checkpoint['model_kwargs'])
|
||||
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
||||
|
||||
|
||||
#discriminator
|
||||
disc_name = 'fdmresdisc'
|
||||
disc = osce_models.model_dict[disc_name](
|
||||
architecture='free',
|
||||
design='f_down',
|
||||
fft_sizes_16k=[2**n for n in range(6, 12)],
|
||||
freq_roi=[0, 7400],
|
||||
max_channels=256,
|
||||
noise_gain=0.0
|
||||
)
|
||||
|
||||
if type(args.initial_checkpoint) != type(None):
|
||||
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
|
||||
|
||||
dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
|
||||
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
|
||||
optimizer_disc = torch.optim.AdamW([p for p in disc.parameters() if p.requires_grad], lr=lr, betas=adam_betas, eps=adam_eps)
|
||||
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
|
||||
scheduler_disc = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_disc, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
|
||||
|
||||
states = None
|
||||
|
||||
spect_loss = MultiResolutionSTFTLoss(device).to(device)
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
batch_count = 0
|
||||
if __name__ == '__main__':
|
||||
model.to(device)
|
||||
disc.to(device)
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
|
||||
m_r = 0
|
||||
m_f = 0
|
||||
s_r = 1
|
||||
s_f = 1
|
||||
|
||||
running_cont_loss = 0
|
||||
running_disc_loss = 0
|
||||
running_gen_loss = 0
|
||||
running_fmap_loss = 0
|
||||
running_reg_loss = 0
|
||||
running_wc = 0
|
||||
|
||||
print(f"training epoch {epoch}...")
|
||||
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
|
||||
for i, (features, periods, target, lpc) in enumerate(tepoch):
|
||||
if epoch == 1 and i == 400:
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.cond_net.parameters():
|
||||
param.requires_grad = False
|
||||
for param in model.sig_net.cond_gain_dense.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
optimizer.zero_grad()
|
||||
features = features.to(device)
|
||||
#lpc = lpc.to(device)
|
||||
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
|
||||
#lpc = fargan.interp_lpc(lpc, 4)
|
||||
periods = periods.to(device)
|
||||
if True:
|
||||
target = target[:, :sequence_length*160]
|
||||
#lpc = lpc[:,:sequence_length*4,:]
|
||||
features = features[:,:sequence_length+4,:]
|
||||
periods = periods[:,:sequence_length+4]
|
||||
else:
|
||||
target=target[::2, :]
|
||||
#lpc=lpc[::2,:]
|
||||
features=features[::2,:]
|
||||
periods=periods[::2,:]
|
||||
target = target.to(device)
|
||||
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
|
||||
|
||||
#nb_pre = random.randrange(1, 6)
|
||||
nb_pre = 2
|
||||
pre = target[:, :nb_pre*160]
|
||||
output, _ = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
|
||||
output = torch.cat([pre, output], -1)
|
||||
|
||||
|
||||
# discriminator update
|
||||
scores_gen = disc(output.detach().unsqueeze(1))
|
||||
scores_real = disc(target.unsqueeze(1))
|
||||
|
||||
disc_loss = 0
|
||||
for scale in scores_gen:
|
||||
disc_loss += ((scale[-1]) ** 2).mean()
|
||||
m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
|
||||
s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
|
||||
|
||||
for scale in scores_real:
|
||||
disc_loss += ((1 - scale[-1]) ** 2).mean()
|
||||
m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
|
||||
s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
|
||||
|
||||
disc_loss = 0.5 * disc_loss / len(scores_gen)
|
||||
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
|
||||
running_wc += winning_chance
|
||||
|
||||
disc.zero_grad()
|
||||
disc_loss.backward()
|
||||
optimizer_disc.step()
|
||||
|
||||
# model update
|
||||
scores_gen = disc(output.unsqueeze(1))
|
||||
if False: # todo: check whether that makes a difference
|
||||
with torch.no_grad():
|
||||
scores_real = disc(target.unsqueeze(1))
|
||||
|
||||
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
|
||||
specc_loss = spect_loss(output, target.detach())
|
||||
reg_loss = (.00*cont_loss + specc_loss)
|
||||
|
||||
loss_gen = 0
|
||||
for scale in scores_gen:
|
||||
loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)
|
||||
|
||||
feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
|
||||
|
||||
reg_weight = args.reg_weight# + 15./(1 + (batch_count/7600.))
|
||||
gen_loss = reg_weight * reg_loss + feat_loss + loss_gen
|
||||
|
||||
model.zero_grad()
|
||||
|
||||
|
||||
gen_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
#model.clip_weights()
|
||||
|
||||
scheduler.step()
|
||||
scheduler_disc.step()
|
||||
|
||||
running_cont_loss += cont_loss.detach().cpu().item()
|
||||
running_gen_loss += loss_gen.detach().cpu().item()
|
||||
running_disc_loss += disc_loss.detach().cpu().item()
|
||||
running_fmap_loss += feat_loss.detach().cpu().item()
|
||||
running_reg_loss += reg_loss.detach().cpu().item()
|
||||
|
||||
|
||||
|
||||
tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
|
||||
reg_weight=f"{reg_weight:8.5f}",
|
||||
gen_loss=f"{running_gen_loss/(i+1):8.5f}",
|
||||
disc_loss=f"{running_disc_loss/(i+1):8.5f}",
|
||||
fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
|
||||
reg_loss=f"{running_reg_loss/(i+1):8.5f}",
|
||||
wc = f"{running_wc/(i+1):8.5f}",
|
||||
)
|
||||
batch_count = batch_count + 1
|
||||
|
||||
# save checkpoint
|
||||
checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['disc_sate_dict'] = disc.state_dict()
|
||||
checkpoint['loss'] = {
|
||||
'cont': running_cont_loss / len(dataloader),
|
||||
'gen': running_gen_loss / len(dataloader),
|
||||
'disc': running_disc_loss / len(dataloader),
|
||||
'fmap': running_fmap_loss / len(dataloader),
|
||||
'reg': running_reg_loss / len(dataloader)
|
||||
}
|
||||
checkpoint['epoch'] = epoch
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
61
managed_components/78__esp-opus/dnn/torch/fargan/dataset.py
Normal file
61
managed_components/78__esp-opus/dnn/torch/fargan/dataset.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import fargan
|
||||
|
||||
class FARGANDataset(torch.utils.data.Dataset):
    """Paired (features, pitch periods, PCM signal, LPC) sequences for FARGAN
    training, memory-mapped from a raw float32 feature file and a raw int16
    signal file.

    Successive items overlap: each item spans two sequence lengths of signal
    (plus 4 extra feature frames of context), with item start positions
    advancing by one sequence length.
    """
    def __init__(self,
                 feature_file,       # raw .f32 file, nb_features floats per frame
                 signal_file,        # raw .s16 PCM file
                 frame_size=160,     # samples per feature frame
                 sequence_length=15, # frames per training sequence
                 lookahead=1,        # feature look-ahead, in frames
                 nb_used_features=20,
                 nb_features=36):

        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        pcm_chunk_size = self.frame_size*self.sequence_length

        self.data = np.memmap(signal_file, dtype='int16', mode='r')
        # keep a margin of 4 chunks so the double-length overlapping windows
        # below never run past the end of the file
        self.nb_sequences = len(self.data)//(pcm_chunk_size)-4
        self.data = self.data[(4-self.lookahead)*self.frame_size:]
        self.data = self.data[:self.nb_sequences*pcm_chunk_size]

        # overlapping views: item i covers samples [i*chunk, i*chunk + 2*chunk)
        sizeof = self.data.strides[-1]
        self.data = np.lib.stride_tricks.as_strided(self.data, shape=(self.nb_sequences, pcm_chunk_size*2),
                                                    strides=(pcm_chunk_size*sizeof, sizeof))

        self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
        sizeof = self.features.strides[-1]
        # matching overlapping feature windows, with 4 extra context frames
        self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
                                                        strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
        # decode an integer pitch period in [32, 255] from the pitch feature
        self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')

        # tail of the feature vector holds the LPC coefficients
        self.lpc = self.features[:, :, self.nb_used_features:]
        self.features = self.features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    def __len__(self):
        # number of training sequences
        return self.nb_sequences

    def __getitem__(self, index):
        """Return (features, periods, signal, lpc); everything is copied so the
        memmap-backed strided views are never handed out directly."""
        features = self.features[index, :, :].copy()
        if self.lookahead != 0:
            # shift LPC by the look-ahead so it is time-aligned with the signal
            lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
        else:
            lpc = self.lpc[index, 4:, :].copy()
        data = self.data[index, :].copy().astype(np.float32) / 2**15  # int16 -> [-1, 1)
        periods = self.periods[index, :].copy()
        return features, periods, data, lpc
|
||||
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
|
||||
import wexchange.torch
|
||||
|
||||
import fargan
|
||||
#from models import model_dict
|
||||
|
||||
unquantized = [ 'cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out' ]
|
||||
|
||||
unquantized2 = [
|
||||
'cond_net.pembed',
|
||||
'cond_net.fdense1',
|
||||
'cond_net.fconv1',
|
||||
'cond_net.fconv2',
|
||||
'cont_net.0',
|
||||
'sig_net.cond_gain_dense',
|
||||
'sig_net.fwc0.conv',
|
||||
'sig_net.fwc0.glu.gate',
|
||||
'sig_net.dense1_glu.gate',
|
||||
'sig_net.gru1_glu.gate',
|
||||
'sig_net.gru2_glu.gate',
|
||||
'sig_net.gru3_glu.gate',
|
||||
'sig_net.skip_glu.gate',
|
||||
'sig_net.skip_dense',
|
||||
'sig_net.sig_dense_out',
|
||||
'sig_net.gain_dense_out'
|
||||
]
|
||||
|
||||
description=f"""
|
||||
This is an unsafe dumping script for FARGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layer
|
||||
and will fail to export any other weights.
|
||||
|
||||
Furthermore, the quanitze option relies on the following explicit list of layers to be excluded:
|
||||
{unquantized}.
|
||||
|
||||
Modify this script manually if adjustments are needed.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument('weightfile', type=str, help='weight file path')
|
||||
parser.add_argument('export_folder', type=str)
|
||||
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
|
||||
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
|
||||
parser.add_argument('--quantize', action='store_true', help='apply quantization')
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    # Load the checkpoint, but overwrite model_args/model_kwargs with the
    # fixed export configuration instead of trusting the checkpoint's values.
    saved_gen= torch.load(args.weightfile, map_location='cpu')
    saved_gen['model_args'] = ()
    saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}

    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    model.load_state_dict(saved_gen['state_dict'], strict=False)
    def _remove_weight_norm(m):
        # Fold the weight-norm reparameterization into plain weights so the
        # exported tensors are the effective ones.
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError: # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model=args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): only `wexchange.torch` is imported at the top of this
    # file; this line relies on `wexchange.c_export` being reachable as an
    # attribute after that import — verify.
    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name, add_typedef=True)

    for name, module in model.named_modules():
        # Layers listed in `unquantized` stay float with a fixed 1/128 scale;
        # everything else is quantized only when --quantize is given.
        if quantize_model:
            quantize=name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize=False
            scale=1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        else:
            print(f"Ignoring layer {name}...")

    writer.close()
|
||||
346
managed_components/78__esp-opus/dnn/torch/fargan/fargan.py
Normal file
346
managed_components/78__esp-opus/dnn/torch/fargan/fargan.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import filters
|
||||
from torch.nn.utils import weight_norm
|
||||
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
|
||||
from rc import lpc2rc, rc2lpc
|
||||
|
||||
source_dir = os.path.split(os.path.abspath(__file__))[0]
|
||||
sys.path.append(os.path.join(source_dir, "../dnntools"))
|
||||
from dnntools.quantization import soft_quant
|
||||
|
||||
|
||||
Fs = 16000
|
||||
|
||||
fid_dict = {}
|
||||
def dump_signal(x, filename):
    """Debug helper: append tensor `x` as raw float32 to `filename`.

    Currently DISABLED via the unconditional early return; the body below is
    kept so debugging can be re-enabled by deleting that one line. File
    handles are cached in the module-level `fid_dict` and kept open.
    """
    return  # debug dumping disabled
    if filename not in fid_dict:
        fid_dict[filename] = open(filename, "w")
    out = fid_dict[filename]
    out_data = x.detach().numpy().astype('float32')
    out_data.tofile(out)
|
||||
|
||||
|
||||
def sig_l1(y_true, y_pred):
    """L1 error normalized by the mean magnitude of the reference signal."""
    err = torch.abs(y_true - y_pred).mean()
    ref = torch.abs(y_true).mean()
    return err / ref
|
||||
|
||||
def sig_loss(y_true, y_pred):
    """Cosine-similarity loss on unit-normalized signals.

    0 when prediction and target are perfectly aligned, 1 when orthogonal.
    The small epsilon guards against division by zero on silent input.
    """
    eps = 1e-15
    t_unit = y_true / (eps + torch.norm(y_true, dim=-1, p=2, keepdim=True))
    p_unit = y_pred / (eps + torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    cos_sim = torch.sum(p_unit * t_unit, dim=-1)
    return torch.mean(1. - cos_sim)
|
||||
|
||||
def interp_lpc(lpc, factor):
    """Temporally upsample LPC coefficients by `factor`.

    Interpolates linearly in the arctanh(reflection-coefficient) domain,
    which keeps every interpolated filter stable (the final tanh maps the
    values back into (-1, 1) before converting to LPC). Edge sub-frames
    that lack a neighbour are filled by replication.
    Note: the local name `lsp` is historical — these are warped reflection
    coefficients, not line spectral pairs.
    """
    lsp = torch.atanh(lpc2rc(lpc))
    shape = lsp.shape
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    for k in range(factor):
        # fractional position of the k-th sub-frame between two coefficient frames
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        # scatter the interpolated values every `factor` positions
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # replicate first/last interpolated values at the sequence edges
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    return rc2lpc(torch.tanh(interp_lsp))
|
||||
|
||||
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Apply the LPC analysis (FIR whitening) filter A(z) to x, frame by frame.

    x:   (batch, nb_frames*nb_subframes*subframe_size) signal
    lpc: (batch, nb_frames, 16) predictor coefficients a_1..a_16
    Returns the residual/excitation signal, same length as x.
    NOTE(review): `gamma` is currently unused — the bandwidth-expansion code
    that consumed it is commented out below.
    """
    device = x.device
    batch_size = lpc.size(0)

    nb_frames = lpc.shape[1]

    # rolling filter input: 16 samples of filter memory + current subframe
    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)

    #if gamma is not None:
    #    bw = gamma**(torch.arange(1, 17, device=device))
    #    lpc = lpc*bw[None,None,:]
    # build the full filter [1, a_1..a_16] per frame, zero-padded so its
    # Toeplitz (convolution) matrix spans a whole subframe
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)

    for n in range(nb_frames):
        for k in range(nb_subframes):
            # shift in the next subframe, filter via matrix multiply, and keep
            # only the outputs corresponding to the new samples
            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)

    return out
|
||||
|
||||
|
||||
# weight initialization and clipping
|
||||
def init_weights(module):
    """Orthogonal initialization for GRU recurrent (weight_hh_*) matrices.

    Intended for use with Module.apply(); modules other than nn.GRU are
    left untouched.
    """
    if not isinstance(module, nn.GRU):
        return
    for name, param in module.named_parameters():
        if name.startswith('weight_hh_'):
            nn.init.orthogonal_(param)
|
||||
|
||||
def gen_phase_embedding(periods, frame_size):
    """Cosine/sine phase embeddings for per-frame pitch periods.

    The phase advances by w0 = 2*pi/period each sample, is continuous across
    frame boundaries, and starts from a random initial phase.
    Returns (cos, sin), each of shape (batch, nb_frames*frame_size).
    """
    device = periods.device
    batch, frames = periods.shape
    omega = 2 * torch.pi / periods
    # random start phase; each frame then starts where the previous one ended
    start = 2 * torch.pi * torch.rand((batch, 1), device=device) / frame_size
    frame_start_phase = frame_size * torch.cumsum(torch.cat([start, omega[:, :-1]], 1), 1)
    sample_idx = torch.arange(frame_size, device=device)
    within_frame = omega[:, :, None] * sample_idx[None, None, :]
    phase = (frame_start_phase[:, :, None] + within_frame).reshape(batch, -1)
    return torch.cos(phase), torch.sin(phase)
|
||||
|
||||
class GLU(nn.Module):
    """Gated linear unit: forward(x) = x * sigmoid(gate(x)).

    The gate is a bias-free, weight-normalized Linear layer; `softquant`
    optionally wraps it for quantization-aware training.
    """
    def __init__(self, feat_size, softquant=False):
        super(GLU, self).__init__()

        # fixed seed so layer initialization is reproducible across runs
        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        if softquant:
            self.gate = soft_quant(self.gate)

        self.init_weights()

    def init_weights(self):
        """Orthogonal initialization for all weight-bearing submodules."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # element-wise gating of the input by its own learned sigmoid gate
        out = x * torch.sigmoid(self.gate(x))

        return out
|
||||
|
||||
class FWConv(nn.Module):
    """Frame-wise "convolution": a linear layer applied to the current input
    concatenated with the previous kernel_size-1 inputs (carried in an
    explicit state tensor), followed by tanh and a GLU.
    """
    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
        super(FWConv, self).__init__()

        # fixed seed so layer initialization is reproducible across runs
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size, softquant=softquant)

        if softquant:
            self.conv = soft_quant(self.conv)

        self.init_weights()

    def init_weights(self):
        """Orthogonal initialization for all weight-bearing submodules."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        """Return (output, new_state); new_state is the tail of the
        concatenated window, i.e. the inputs the next call will reuse."""
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        return out, xcat[:,self.in_size:]
|
||||
|
||||
def n(x):
    """Simulate round-to-8-bit quantization: add uniform noise of +/- half an
    LSB (1/127) and clamp to the representable range [-1, 1]."""
    half_lsb_noise = (torch.rand_like(x) - .5) * (1. / 127.)
    return torch.clamp(x + half_lsb_noise, min=-1., max=1.)
|
||||
|
||||
class FARGANCond(nn.Module):
    """Frame-rate conditioning network of FARGAN.

    Maps acoustic features plus a pitch-period embedding to one 80*4-dim
    conditioning vector per frame (80 values per subframe).
    """
    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
        super(FARGANCond, self).__init__()

        self.feature_dim = feature_dim
        self.cond_size = cond_size

        # pitch-period embedding: valid periods are 32..255 (224 entries)
        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)

        if softquant:
            self.fconv1 = soft_quant(self.fconv1)
            self.fdense2 = soft_quant(self.fdense2)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        # Drop the first two context frames; together with the 'valid'
        # 3-tap conv below this consumes the extra context frames fed by
        # the caller (presumably the 4 look-around frames — verify).
        features = features[:,2:,:]
        period = period[:,2:]
        p = self.pembed(period-32)  # shift period range to embedding indices
        features = torch.cat((features, p), -1)
        tmp = torch.tanh(self.fdense1(features))
        tmp = tmp.permute(0, 2, 1)  # (B, T, C) -> (B, C, T) for Conv1d
        tmp = torch.tanh(self.fconv1(tmp))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fdense2(tmp))
        return tmp
|
||||
|
||||
class FARGANSub(nn.Module):
    """Sample-rate (subframe) synthesis network of FARGAN.

    Generates one subframe_size-sample excitation chunk per call from the
    80-dim conditioning vector, a pitch-delayed prediction gathered from the
    256-sample excitation history, and recurrent state.
    """
    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
        super(FARGANSub, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        # predicts the log-gain from the 80-dim conditioning vector
        self.cond_gain_dense = nn.Linear(80, 1)

        # input width: cond (80) + pitch prediction (subframe_size+4)
        # + previous subframe (subframe_size) = 2*subframe_size+80+4
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

        self.gru1_glu = GLU(160, softquant=softquant)
        self.gru2_glu = GLU(128, softquant=softquant)
        self.gru3_glu = GLU(128, softquant=softquant)
        self.skip_glu = GLU(128, softquant=softquant)

        # skip connection over all intermediate activations
        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        # four per-stage pitch gains, one per injection point in forward()
        self.gain_dense_out = nn.Linear(192, 4)

        if softquant:
            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
            self.skip_dense = soft_quant(self.skip_dense)
            self.sig_dense_out = soft_quant(self.sig_dense_out)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Synthesize one subframe.

        Returns (sig_out, exc_mem, prev_pred, new_states).
        NOTE(review): the incoming `gain` argument is only logged via
        dump_signal() and then overwritten with the network-predicted gain.
        """
        device = exc_mem.device

        cond = n(cond)  # n() simulates 8-bit quantization noise
        dump_signal(gain, 'gain0.f32')
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')
        # gather subframe_size+4 samples delayed by one pitch period (with
        # 2 extra samples on each side) from the excitation history
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        mask = idx >= 256
        idx = idx - mask*period[:,None]  # wrap short periods by one more period
        pred = torch.gather(exc_mem, 1, idx)
        pred = n(pred/(1e-5+gain))  # gain-normalize, add quantization noise

        prev = exc_mem[:,-self.subframe_size:]  # previous output subframe
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')

        tmp = torch.cat((cond, pred, prev), 1)
        fpitch = pred[:,2:-2]  # center subframe of the pitch prediction

        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

        # three GRU stages, each re-injecting the gain-scaled pitch
        # prediction and the previous subframe
        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)
        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')

        dump_signal(pitch_gain, 'pgain.f32')
        sig_out = sig_out * gain  # undo the gain normalization
        # shift the new subframe into the history buffers
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
|
||||
|
||||
class FARGAN(nn.Module):
    """Framewise Auto-Regressive GAN vocoder.

    Combines a frame-rate conditioning network (FARGANCond) with an
    auto-regressive subframe synthesis network (FARGANSub).
    `passthrough_size`, `has_gain` and `gamma` are accepted in the signature
    but unused here — presumably kept for checkpoint compatibility (verify).
    """
    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Synthesize nb_frames frames of signal.

        pre: optional teacher-forcing prefix samples that seed the
        excitation history and replace the output for the prefix frames.
        NOTE(review): the `states` argument is ignored — it is
        unconditionally re-initialized to zeros below.
        """
        device = features.device
        batch_size = features.size(0)

        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        states = (
            torch.zeros(batch_size, 160, device=device),  # gru1 hidden
            torch.zeros(batch_size, 128, device=device),  # gru2 hidden
            torch.zeros(batch_size, 128, device=device),  # gru3 hidden
            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)  # fwc0 window
        )

        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        if pre is not None:
            exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
        start = 1 if nb_pre_frames>0 else 0
        # NOTE: the loop variable `n` shadows the module-level noise helper n()
        for n in range(start, nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = n*self.frame_size + k*self.subframe_size
                pitch = period[:, 3+n]
                # energy feature -> linear gain; constants look empirical and
                # presumably match the feature extractor's scaling — verify
                gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)

                if n < nb_pre_frames:
                    # teacher forcing: overwrite the output with ground truth
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)

        states = [s.detach() for s in states]
        return sig, states
|
||||
46
managed_components/78__esp-opus/dnn/torch/fargan/filters.py
Normal file
46
managed_components/78__esp-opus/dnn/torch/fargan/filters.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
def toeplitz_from_filter(a):
|
||||
device = a.device
|
||||
L = a.size(-1)
|
||||
size0 = (*(a.shape[:-1]), L, L+1)
|
||||
size = (*(a.shape[:-1]), L, L)
|
||||
rnge = torch.arange(0, L, dtype=torch.int64, device=device)
|
||||
z = torch.tensor(0, device=device)
|
||||
idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z)
|
||||
a = torch.cat([a[...,:1]*0, a], -1)
|
||||
#print(a)
|
||||
a = a[...,None,:]
|
||||
#print(idx)
|
||||
a = torch.broadcast_to(a, size0)
|
||||
idx = torch.broadcast_to(idx, size)
|
||||
#print(idx)
|
||||
return torch.gather(a, -1, idx)
|
||||
|
||||
def filter_iir_response(a, N):
|
||||
device = a.device
|
||||
L = a.size(-1)
|
||||
ar = a.flip(dims=(2,))
|
||||
size = (*(a.shape[:-1]), N)
|
||||
R = torch.zeros(size, device=device)
|
||||
R[:,:,0] = torch.ones((a.shape[:-1]), device=device)
|
||||
for i in range(1, L):
|
||||
R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1)
|
||||
#R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i])
|
||||
for i in range(L, N):
|
||||
R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1)
|
||||
#R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i])
|
||||
return R
|
||||
|
||||
if __name__ == '__main__':
|
||||
#a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
|
||||
a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
|
||||
A = toeplitz_from_filter(a)
|
||||
#print(A)
|
||||
R = filter_iir_response(a, 5)
|
||||
|
||||
RA = toeplitz_from_filter(R)
|
||||
print(RA)
|
||||
29
managed_components/78__esp-opus/dnn/torch/fargan/rc.py
Normal file
29
managed_components/78__esp-opus/dnn/torch/fargan/rc.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def rc2lpc(rc):
|
||||
order = rc.shape[-1]
|
||||
lpc=rc[...,0:1]
|
||||
for i in range(1, order):
|
||||
lpc = torch.cat([lpc + rc[...,i:i+1]*torch.flip(lpc,dims=(-1,)), rc[...,i:i+1]], -1)
|
||||
#print("to:", lpc)
|
||||
return lpc
|
||||
|
||||
def lpc2rc(lpc):
|
||||
order = lpc.shape[-1]
|
||||
rc = lpc[...,-1:]
|
||||
for i in range(order-1, 0, -1):
|
||||
ki = lpc[...,-1:]
|
||||
lpc = lpc[...,:-1]
|
||||
lpc = (lpc - ki*torch.flip(lpc,dims=(-1,)))/(1 - ki*ki)
|
||||
rc = torch.cat([lpc[...,-1:] , rc], -1)
|
||||
return rc
|
||||
|
||||
if __name__ == "__main__":
|
||||
rc = torch.tensor([[.5, -.5, .6, -.6]])
|
||||
print(rc)
|
||||
lpc = rc2lpc(rc)
|
||||
print(lpc)
|
||||
rc2 = lpc2rc(lpc)
|
||||
print(rc2)
|
||||
186
managed_components/78__esp-opus/dnn/torch/fargan/stft_loss.py
Normal file
186
managed_components/78__esp-opus/dnn/torch/fargan/stft_loss.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
|
||||
#x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
|
||||
#real = x_stft[..., 0]
|
||||
#imag = x_stft[..., 1]
|
||||
|
||||
# (kan-bayashi): clamp is needed to avoid nan or inf
|
||||
#return torchaudio.functional.amplitude_to_DB(torch.abs(x_stft),db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
|
||||
#return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
|
||||
return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
class SpectralConvergenceLoss(torch.nn.Module):
|
||||
"""Spectral convergence loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize spectral convergence loss module."""
|
||||
super(SpectralConvergenceLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
"""
|
||||
x_mag = torch.sqrt(x_mag)
|
||||
y_mag = torch.sqrt(y_mag)
|
||||
return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
|
||||
|
||||
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
||||
"""Log STFT magnitude loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize los STFT magnitude loss module."""
|
||||
super(LogSTFTMagnitudeLoss, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
"""
|
||||
#F.l1_loss(torch.sqrt(y_mag), torch.sqrt(x_mag)) +
|
||||
#F.l1_loss(torchaudio.functional.amplitude_to_DB(y_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80),\
|
||||
#torchaudio.functional.amplitude_to_DB(x_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80))
|
||||
|
||||
#y_mag[:,:y_mag.size(1)//2,:] = y_mag[:,:y_mag.size(1)//2,:] *0.0
|
||||
|
||||
#return F.l1_loss(torch.log(y_mag) + torch.sqrt(y_mag), torch.log(x_mag) + torch.sqrt(x_mag))
|
||||
|
||||
#return F.l1_loss(y_mag, x_mag)
|
||||
|
||||
error_loss = F.l1_loss(y, x) #+ F.l1_loss(torch.sqrt(y), torch.sqrt(x))#F.l1_loss(torch.log(y), torch.log(x))#
|
||||
|
||||
#x = torch.log(x)
|
||||
#y = torch.log(y)
|
||||
#x = x.permute(0,2,1).contiguous()
|
||||
#y = y.permute(0,2,1).contiguous()
|
||||
|
||||
'''mean_x = torch.mean(x, dim=1, keepdim=True)
|
||||
mean_y = torch.mean(y, dim=1, keepdim=True)
|
||||
|
||||
var_x = torch.var(x, dim=1, keepdim=True)
|
||||
var_y = torch.var(y, dim=1, keepdim=True)
|
||||
|
||||
std_x = torch.std(x, dim=1, keepdim=True)
|
||||
std_y = torch.std(y, dim=1, keepdim=True)
|
||||
|
||||
x_minus_mean = x - mean_x
|
||||
y_minus_mean = y - mean_y
|
||||
|
||||
pearson_corr = torch.sum(x_minus_mean * y_minus_mean, dim=1, keepdim=True) / \
|
||||
(torch.sqrt(torch.sum(x_minus_mean ** 2, dim=1, keepdim=True) + 1e-7) * \
|
||||
torch.sqrt(torch.sum(y_minus_mean ** 2, dim=1, keepdim=True) + 1e-7))
|
||||
|
||||
numerator = 2.0 * pearson_corr * std_x * std_y
|
||||
denominator = var_x + var_y + (mean_y - mean_x)**2
|
||||
|
||||
ccc = numerator/denominator
|
||||
|
||||
ccc_loss = F.l1_loss(1.0 - ccc, torch.zeros_like(ccc))'''
|
||||
|
||||
return error_loss #+ ccc_loss#+ ccc_loss
|
||||
|
||||
|
||||
class STFTLoss(torch.nn.Module):
|
||||
"""STFT loss module."""
|
||||
|
||||
def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
|
||||
"""Initialize STFT loss module."""
|
||||
super(STFTLoss, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.window = getattr(torch, window)(win_length).to(device)
|
||||
self.spectral_convergenge_loss = SpectralConvergenceLoss()
|
||||
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
"""
|
||||
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
|
||||
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
|
||||
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
||||
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
|
||||
|
||||
class MultiResolutionSTFTLoss(torch.nn.Module):
|
||||
|
||||
'''def __init__(self,
|
||||
device,
|
||||
fft_sizes=[2048, 1024, 512, 256, 128, 64],
|
||||
hop_sizes=[512, 256, 128, 64, 32, 16],
|
||||
win_lengths=[2048, 1024, 512, 256, 128, 64],
|
||||
window="hann_window"):'''
|
||||
|
||||
'''def __init__(self,
|
||||
device,
|
||||
fft_sizes=[2048, 1024, 512, 256, 128, 64],
|
||||
hop_sizes=[256, 128, 64, 32, 16, 8],
|
||||
win_lengths=[1024, 512, 256, 128, 64, 32],
|
||||
window="hann_window"):'''
|
||||
|
||||
def __init__(self,
|
||||
device,
|
||||
fft_sizes=[2560, 1280, 640, 320, 160, 80],
|
||||
hop_sizes=[640, 320, 160, 80, 40, 20],
|
||||
win_lengths=[2560, 1280, 640, 320, 160, 80],
|
||||
window="hann_window"):
|
||||
|
||||
super(MultiResolutionSTFTLoss, self).__init__()
|
||||
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||
self.stft_losses = torch.nn.ModuleList()
|
||||
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||
self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Multi resolution spectral convergence loss value.
|
||||
Tensor: Multi resolution log STFT magnitude loss value.
|
||||
"""
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
for f in self.stft_losses:
|
||||
sc_l, mag_l = f(x, y)
|
||||
sc_loss += sc_l
|
||||
#mag_loss += mag_l
|
||||
sc_loss /= len(self.stft_losses)
|
||||
mag_loss /= len(self.stft_losses)
|
||||
|
||||
return sc_loss #mag_loss #+
|
||||
128
managed_components/78__esp-opus/dnn/torch/fargan/test_fargan.py
Normal file
128
managed_components/78__esp-opus/dnn/torch/fargan/test_fargan.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
import fargan
|
||||
from dataset import FARGANDataset
|
||||
|
||||
nb_features = 36
|
||||
nb_used_features = 20
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('model', type=str, help='CELPNet model')
|
||||
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
|
||||
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
|
||||
|
||||
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
|
||||
|
||||
|
||||
model_group = parser.add_argument_group(title="model parameters")
|
||||
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cuda_visible_devices != None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
|
||||
|
||||
|
||||
features_file = args.features
|
||||
signal_file = args.output
|
||||
|
||||
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
checkpoint = torch.load(args.model, map_location='cpu')
|
||||
|
||||
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
||||
|
||||
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
|
||||
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
|
||||
lpc = features[:,4-1:-1,nb_used_features:]
|
||||
features = features[:, :, :nb_used_features]
|
||||
#periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
|
||||
periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')
|
||||
|
||||
|
||||
nb_frames = features.shape[1]
|
||||
#nb_frames = 1000
|
||||
gamma = checkpoint['model_kwargs']['gamma']
|
||||
|
||||
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
|
||||
|
||||
out = np.zeros_like(frame)
|
||||
filt = np.flip(filt)
|
||||
|
||||
inp = frame[:]
|
||||
|
||||
|
||||
for i in range(0, inp.shape[0]):
|
||||
|
||||
s = inp[i] - np.dot(buffer*weighting_vector, filt)
|
||||
|
||||
buffer[0] = s
|
||||
|
||||
buffer = np.roll(buffer, -1)
|
||||
|
||||
out[i] = s
|
||||
|
||||
return out
|
||||
|
||||
def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
|
||||
|
||||
#inverse perceptual weighting= H_preemph / W(z/gamma)
|
||||
|
||||
signal = np.zeros_like(pw_signal)
|
||||
buffer = np.zeros(16)
|
||||
num_frames = pw_signal.shape[0] //160
|
||||
assert num_frames == filters.shape[0]
|
||||
for frame_idx in range(0, num_frames):
|
||||
|
||||
in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
|
||||
out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
|
||||
signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
|
||||
buffer[:] = out_sig_frame[-16:]
|
||||
return signal
|
||||
|
||||
def inverse_perceptual_weighting40 (pw_signal, filters):
|
||||
|
||||
#inverse perceptual weighting= H_preemph / W(z/gamma)
|
||||
|
||||
signal = np.zeros_like(pw_signal)
|
||||
buffer = np.zeros(16)
|
||||
num_frames = pw_signal.shape[0] //40
|
||||
assert num_frames == filters.shape[0]
|
||||
for frame_idx in range(0, num_frames):
|
||||
in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:]
|
||||
out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer)
|
||||
signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:]
|
||||
buffer[:] = out_sig_frame[-16:]
|
||||
return signal
|
||||
|
||||
from scipy.signal import lfilter
|
||||
|
||||
if __name__ == '__main__':
|
||||
model.to(device)
|
||||
features = torch.tensor(features).to(device)
|
||||
#lpc = torch.tensor(lpc).to(device)
|
||||
periods = torch.tensor(periods).to(device)
|
||||
weighting = gamma**np.arange(1, 17)
|
||||
lpc = lpc*weighting
|
||||
lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()
|
||||
|
||||
sig, _ = model(features, periods, nb_frames - 4)
|
||||
#weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
|
||||
sig = sig.detach().numpy().flatten()
|
||||
sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
|
||||
#sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
|
||||
|
||||
pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
|
||||
pcm.tofile(signal_file)
|
||||
169
managed_components/78__esp-opus/dnn/torch/fargan/train_fargan.py
Normal file
169
managed_components/78__esp-opus/dnn/torch/fargan/train_fargan.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
import fargan
|
||||
from dataset import FARGANDataset
|
||||
from stft_loss import *
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
|
||||
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
|
||||
parser.add_argument('output', type=str, help='path to output folder')
|
||||
|
||||
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
|
||||
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
|
||||
|
||||
|
||||
model_group = parser.add_argument_group(title="model parameters")
|
||||
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
|
||||
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
|
||||
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
|
||||
|
||||
training_group = parser.add_argument_group(title="training parameters")
|
||||
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
|
||||
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
|
||||
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
|
||||
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
|
||||
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
|
||||
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cuda_visible_devices != None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
|
||||
|
||||
# checkpoints
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
checkpoint = dict()
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
|
||||
# training parameters
|
||||
batch_size = args.batch_size
|
||||
lr = args.lr
|
||||
epochs = args.epochs
|
||||
sequence_length = args.sequence_length
|
||||
lr_decay = args.lr_decay
|
||||
|
||||
adam_betas = [0.8, 0.95]
|
||||
adam_eps = 1e-8
|
||||
features_file = args.features
|
||||
signal_file = args.signal
|
||||
|
||||
# model parameters
|
||||
cond_size = args.cond_size
|
||||
|
||||
|
||||
checkpoint['batch_size'] = batch_size
|
||||
checkpoint['lr'] = lr
|
||||
checkpoint['lr_decay'] = lr_decay
|
||||
checkpoint['epochs'] = epochs
|
||||
checkpoint['sequence_length'] = sequence_length
|
||||
checkpoint['adam_betas'] = adam_betas
|
||||
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
checkpoint['model_args'] = ()
|
||||
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
|
||||
print(checkpoint['model_kwargs'])
|
||||
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
|
||||
|
||||
#model = fargan.FARGAN()
|
||||
#model = nn.DataParallel(model)
|
||||
|
||||
if type(args.initial_checkpoint) != type(None):
|
||||
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
|
||||
|
||||
dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
|
||||
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
|
||||
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
|
||||
|
||||
states = None
|
||||
|
||||
spect_loss = MultiResolutionSTFTLoss(device).to(device)
|
||||
|
||||
if __name__ == '__main__':
|
||||
model.to(device)
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
|
||||
running_specc = 0
|
||||
running_cont_loss = 0
|
||||
running_loss = 0
|
||||
|
||||
print(f"training epoch {epoch}...")
|
||||
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
|
||||
for i, (features, periods, target, lpc) in enumerate(tepoch):
|
||||
optimizer.zero_grad()
|
||||
features = features.to(device)
|
||||
#lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
|
||||
#print("interp size", lpc.shape)
|
||||
#lpc = lpc.to(device)
|
||||
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
|
||||
#lpc = fargan.interp_lpc(lpc, 4)
|
||||
periods = periods.to(device)
|
||||
if (np.random.rand() > 0.1):
|
||||
target = target[:, :sequence_length*160]
|
||||
#lpc = lpc[:,:sequence_length*4,:]
|
||||
features = features[:,:sequence_length+4,:]
|
||||
periods = periods[:,:sequence_length+4]
|
||||
else:
|
||||
target=target[::2, :]
|
||||
#lpc=lpc[::2,:]
|
||||
features=features[::2,:]
|
||||
periods=periods[::2,:]
|
||||
target = target.to(device)
|
||||
#print(target.shape, lpc.shape)
|
||||
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
|
||||
|
||||
#nb_pre = random.randrange(1, 6)
|
||||
nb_pre = 2
|
||||
pre = target[:, :nb_pre*160]
|
||||
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
|
||||
sig = torch.cat([pre, sig], -1)
|
||||
|
||||
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
|
||||
specc_loss = spect_loss(sig, target.detach())
|
||||
loss = .03*cont_loss + specc_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
#model.clip_weights()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
running_specc += specc_loss.detach().cpu().item()
|
||||
running_cont_loss += cont_loss.detach().cpu().item()
|
||||
|
||||
running_loss += loss.detach().cpu().item()
|
||||
tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
|
||||
cont_loss=f"{running_cont_loss/(i+1):8.5f}",
|
||||
specc=f"{running_specc/(i+1):8.5f}",
|
||||
)
|
||||
|
||||
# save checkpoint
|
||||
checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['loss'] = running_loss / len(dataloader)
|
||||
checkpoint['epoch'] = epoch
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
|
||||
import wexchange.torch
|
||||
|
||||
from models import model_dict
|
||||
|
||||
unquantized = [
|
||||
'bfcc_with_corr_upsampler.fc',
|
||||
'cont_net.0',
|
||||
'fwc6.cont_fc.0',
|
||||
'fwc6.fc.0',
|
||||
'fwc6.fc.1.gate',
|
||||
'fwc7.cont_fc.0',
|
||||
'fwc7.fc.0',
|
||||
'fwc7.fc.1.gate'
|
||||
]
|
||||
|
||||
description=f"""
|
||||
This is an unsafe dumping script for FWGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layer
|
||||
and will fail to export any other weights.
|
||||
|
||||
Furthermore, the quanitze option relies on the following explicit list of layers to be excluded:
|
||||
{unquantized}.
|
||||
|
||||
Modify this script manually if adjustments are needed.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
|
||||
parser.add_argument('weightfile', type=str, help='weight file path')
|
||||
parser.add_argument('export_folder', type=str)
|
||||
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
|
||||
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
|
||||
parser.add_argument('--quantize', action='store_true', help='apply quantization')
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
model = model_dict[args.model]()
|
||||
|
||||
print(f"loading weights from {args.weightfile}...")
|
||||
saved_gen= torch.load(args.weightfile, map_location='cpu')
|
||||
model.load_state_dict(saved_gen)
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
model.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
print("dumping model...")
|
||||
quantize_model=args.quantize
|
||||
|
||||
output_folder = args.export_folder
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
if quantize_model:
|
||||
quantize=name not in unquantized
|
||||
scale = None if quantize else 1/128
|
||||
else:
|
||||
quantize=False
|
||||
scale=1/128
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
print(f"dumping linear layer {name}...")
|
||||
wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
|
||||
|
||||
if isinstance(module, nn.Conv1d):
|
||||
print(f"dumping conv1d layer {name}...")
|
||||
wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
|
||||
|
||||
if isinstance(module, nn.GRU):
|
||||
print(f"dumping GRU layer {name}...")
|
||||
wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
|
||||
|
||||
writer.close()
|
||||
141
managed_components/78__esp-opus/dnn/torch/fwgan/inference.py
Normal file
141
managed_components/78__esp-opus/dnn/torch/fwgan/inference.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy import signal as si
|
||||
from scipy.io import wavfile
|
||||
import argparse
|
||||
|
||||
from models import model_dict
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
|
||||
parser.add_argument('weightfile', type=str, help='weight file')
|
||||
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
|
||||
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')
|
||||
|
||||
|
||||
########################### Signal Processing Layers ###########################
|
||||
|
||||
def preemphasis(x, coef= -0.85):
|
||||
|
||||
return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')
|
||||
|
||||
def deemphasis(x, coef= -0.85):
|
||||
|
||||
return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')
|
||||
|
||||
gamma = 0.92
|
||||
weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
|
||||
|
||||
|
||||
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
|
||||
|
||||
out = np.zeros_like(frame)
|
||||
|
||||
filt = np.flip(filt)
|
||||
|
||||
inp = frame[:]
|
||||
|
||||
|
||||
for i in range(0, inp.shape[0]):
|
||||
|
||||
s = inp[i] - np.dot(buffer*weighting_vector, filt)
|
||||
|
||||
buffer[0] = s
|
||||
|
||||
buffer = np.roll(buffer, -1)
|
||||
|
||||
out[i] = s
|
||||
|
||||
return out
|
||||
|
||||
def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
|
||||
|
||||
#inverse perceptual weighting= H_preemph / W(z/gamma)
|
||||
|
||||
pw_signal = preemphasis(pw_signal)
|
||||
|
||||
signal = np.zeros_like(pw_signal)
|
||||
buffer = np.zeros(16)
|
||||
num_frames = pw_signal.shape[0] //160
|
||||
assert num_frames == filters.shape[0]
|
||||
|
||||
for frame_idx in range(0, num_frames):
|
||||
|
||||
in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
|
||||
out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
|
||||
signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
|
||||
buffer[:] = out_sig_frame[-16:]
|
||||
|
||||
return signal
|
||||
|
||||
|
||||
def process_item(generator, feature_filename, output_filename, verbose=False):
|
||||
|
||||
feat = np.memmap(feature_filename, dtype='float32', mode='r')
|
||||
|
||||
num_feat_frames = len(feat) // 36
|
||||
feat = np.reshape(feat, (num_feat_frames, 36))
|
||||
|
||||
bfcc = np.copy(feat[:, :18])
|
||||
corr = np.copy(feat[:, 19:20]) + 0.5
|
||||
bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)#.to(device)
|
||||
|
||||
period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)\
|
||||
.astype('int32')).type(torch.long).view(1,-1)#.to(device)
|
||||
|
||||
lpc_filters = np.copy(feat[:, -16:])
|
||||
|
||||
start_time = time.time()
|
||||
x1 = generator(period, bfcc_with_corr, torch.zeros(1,320)) #this means the vocoder runs in complete synthesis mode with zero history audio frames
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
|
||||
gen_seconds = len(x1)/16000
|
||||
out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
|
||||
if verbose:
|
||||
print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")
|
||||
|
||||
out = np.clip(np.round(2**15 * out), -2**15, 2**15 -1).astype(np.int16)
|
||||
wavfile.write(output_filename, 16000, out)
|
||||
|
||||
|
||||
########################### The inference loop over folder containing lpcnet feature files #################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
generator = model_dict[args.model]()
|
||||
|
||||
|
||||
#Load the FWGAN500Hz Checkpoint
|
||||
saved_gen= torch.load(args.weightfile, map_location='cpu')
|
||||
generator.load_state_dict(saved_gen)
|
||||
|
||||
#this is just to remove the weight_norm from the model layers as it's no longer needed
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
generator.apply(_remove_weight_norm)
|
||||
|
||||
#enable inference mode
|
||||
generator = generator.eval()
|
||||
|
||||
print('Successfully loaded the generator model ... start generation:')
|
||||
|
||||
if os.path.isdir(args.input):
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
for fn in os.listdir(args.input):
|
||||
print(f"processing input {fn}...")
|
||||
feature_filename = os.path.join(args.input, fn)
|
||||
output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
|
||||
process_item(generator, feature_filename, output_filename)
|
||||
else:
|
||||
process_item(generator, args.input, args.output)
|
||||
|
||||
print("Finished!")
|
||||
@@ -0,0 +1,7 @@
|
||||
from .fwgan400 import FWGAN400ContLarge
|
||||
from .fwgan500 import FWGAN500Cont
|
||||
|
||||
model_dict = {
|
||||
'fwgan400': FWGAN400ContLarge,
|
||||
'fwgan500': FWGAN500Cont
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
import numpy as np
|
||||
|
||||
which_norm = weight_norm
|
||||
|
||||
#################### Definition of basic model components ####################
|
||||
|
||||
#Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
|
||||
class ConvLookahead(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias= False):
|
||||
super(ConvLookahead, self).__init__()
|
||||
torch.manual_seed(5)
|
||||
|
||||
self.padding_left = (kernel_size - 2) * dilation
|
||||
self.padding_right = 1 * dilation
|
||||
|
||||
self.conv = which_norm(nn.Conv1d(in_ch,out_ch,kernel_size,dilation=dilation, groups=groups, bias= bias))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = F.pad(x,(self.padding_left, self.padding_right))
|
||||
conv_out = self.conv(x)
|
||||
return conv_out
|
||||
|
||||
#(modified) GLU Activation layer definition
|
||||
class GLU(nn.Module):
    """Modified gated linear unit: tanh(x) gated by sigmoid(W x).

    Unlike the standard GLU, the value path is a plain tanh of the input;
    only the gate has learned (weight-normalized, bias-free) weights.
    """
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        # Fixed seed so the orthogonal init below is reproducible.
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Orthogonal init of all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out
|
||||
|
||||
#GRU layer definition
|
||||
class ContForwardGRU(nn.Module):
    """Forward GRU whose initial hidden state is derived from a 64-dim
    continuation latent instead of zeros, followed by a GLU."""
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()

        # Fixed seed so the orthogonal init below is reproducible.
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # Maps the 64-dim continuation latent to the initial hidden state.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        # Orthogonal init of conv/linear/embedding submodules.
        # NOTE: nn.GRU is not in this isinstance list, so the GRU weights
        # keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        """x: (batch, T, input_size); x0: (batch, 64) continuation latent."""
        self.gru.flatten_parameters()

        # Derive h0 from the continuation latent; shape (1, batch, hidden).
        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        # Gated nonlinearity on the full output sequence; the final hidden
        # state is discarded (the next call re-derives it from x0).
        return self.nl(output)
|
||||
|
||||
# Framewise convolution layer definition
|
||||
class ContFramewiseConv(torch.nn.Module):
    """Frame-wise "convolution": an FC applied to frame_kernel_size
    consecutive input frames per output frame, where the missing left
    context of the first frame(s) is synthesized from a 64-dim
    continuation latent instead of zero padding."""

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):

        super(ContFramewiseConv, self).__init__()
        # Fixed seed so the orthogonal init below is reproducible.
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            # Causal: all extra context comes from the past.
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # Produces the left-context samples from the continuation latent.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )

        else:
            # Non-causal (symmetric context).
            # NOTE(review): this branch defines no cont_fc, yet forward()
            # unconditionally calls self.cont_fc and never applies
            # required_pad_right — the non-causal path appears unsupported.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        # Output activation: gated (GLU) or plain tanh.
        if act=='glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act=='tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )

        self.init_weights()


    def init_weights(self):
        # Orthogonal init of all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        """x: (batch, frames, frame_len); x0: (batch, 64) continuation latent."""
        if self.frame_kernel_size == 1:
            # No context needed: plain per-frame FC.
            return self.fc(x)

        # Flatten frames to one long sequence and prepend synthesized context.
        x_flat = x.reshape(x.size(0),1,-1)
        pad = self.cont_fc(x0).view(x0.size(0),1,-1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        # Sliding windows of frame_kernel_size frames, one window per frame.
        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                                          kernel_size= (1,self.fc_input_dim), stride=self.frame_len).permute(0,2,1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out
|
||||
|
||||
# A fully-connected based upsampling layer definition
|
||||
class UpsampleFC(nn.Module):
    """Time-upsampling by a fixed integer factor via one linear layer.

    Each input frame of `in_ch` channels is mapped (Linear + Tanh) to
    `upsample_factor` output frames of `out_ch` channels.
    """

    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every weight-bearing submodule.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d,
                                  nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(layer.weight.data)

    def forward(self, x):
        """x: (batch, in_ch, T) -> (batch, out_ch, T * upsample_factor)."""
        n_batch = x.size(0)
        frames = x.permute(0, 2, 1)             # (batch, T, in_ch)
        expanded = self.nl(self.fc(frames))     # (batch, T, out_ch * factor)
        expanded = expanded.reshape((n_batch, -1, self.out_ch))
        return expanded.permute(0, 2, 1)        # (batch, out_ch, T * factor)
|
||||
|
||||
########################### The complete model definition #################################
|
||||
|
||||
class FWGAN400ContLarge(nn.Module):
    """FWGAN vocoder (400-series, large variant) with continuation.

    Synthesizes a waveform from per-frame pitch periods and 19-dim
    BFCC(+correlation) features. `x0` carries 320 history audio samples;
    a small continuation network turns them into a 64-dim latent that
    initializes the GRU state and the left context of every frame-wise
    conv layer, so generation continues smoothly from prior audio.
    """
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)  # reproducible initialization

        # Upsample 19-dim features 4x in time to 80 channels.
        self.bfcc_with_corr_upsampler = UpsampleFC(19,80,4)

        # PreCondNet: conv with one frame of look-ahead + gated activation.
        self.feat_in_conv1 = ConvLookahead(160,256,kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        # Continuation network: maps the 321-dim history vector
        # (log-energy + 320 normalized samples, built in forward()) down
        # to the 64-dim continuation latent.
        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256,256)

        # Frame-wise convolution stack narrowing 256 -> 40 channels;
        # the last layer emits 40 waveform samples per subframe.
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        # Orthogonal init of all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        # Print the number of trainable parameters (called once at init).
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        """Build per-sample sin/cos phase embeddings from pitch periods.

        periods: (batch, frames). Within each 160-sample frame the phase
        advances by 2*pi/period per sample; the accumulated phase carries
        across frames. Samples are regrouped into 40-sample subframes and
        sin/cos halves are concatenated, yielding (batch, frames*4, 80).
        """
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0),-1,40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0),-1,40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim = -1)

            # Carry the accumulated phase into the next frame.
            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals


    def gain_multiply(self, x, c0):
        """Scale the waveform by a per-frame gain derived from c0
        (presumably the first BFCC coefficient / energy — TODO confirm)."""
        gain = 10**(0.5*c0/np.sqrt(18.0))
        # Expand per-frame gain to per-sample (160 samples per frame).
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0),1,-1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        """pitch_period: (batch, frames); bfcc_with_corr: (batch, frames, 19);
        x0: (batch, 320) history samples. Returns (batch, samples) waveform."""
        # Normalize history samples and prepend their log-energy (-> 321 dims).
        norm_x0 = torch.norm(x0,2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt((1e-8) + norm_x0**2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()

        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0,2,1).contiguous())

        # 80 phase channels + 80 envelope channels -> 160-channel conditioning.
        feat_in = torch.cat((p_embed , envelope), dim=1)

        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0,2,1).contiguous())

        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)

        fwc1_out = self.fwc1(rnn_out, cont_latent)

        fwc2_out = self.fwc2(fwc1_out, cont_latent)

        fwc3_out = self.fwc3(fwc2_out, cont_latent)

        fwc4_out = self.fwc4(fwc3_out, cont_latent)

        fwc5_out = self.fwc5(fwc4_out, cont_latent)

        fwc6_out = self.fwc6(fwc5_out, cont_latent)

        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        # Subframes -> flat waveform, then apply per-frame gain.
        waveform = fwc7_out.reshape(fwc7_out.size(0),1,-1).squeeze(1)

        waveform = self.gain_multiply(waveform,bfcc_with_corr[:,:,:1])

        return waveform
|
||||
@@ -0,0 +1,260 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
import numpy as np
|
||||
|
||||
|
||||
which_norm = weight_norm
|
||||
|
||||
#################### Definition of basic model components ####################
|
||||
|
||||
#Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
|
||||
class ConvLookahead(nn.Module):
    """1D convolution whose asymmetric padding grants exactly one dilated
    step of future context (one frame of look-ahead)."""

    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)  # deterministic initialization

        # All but one step of the receptive field on the left, one on the right.
        self.padding_left = dilation * (kernel_size - 2)
        self.padding_right = dilation

        self.conv = which_norm(
            nn.Conv1d(in_ch, out_ch, kernel_size,
                      dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        # Orthogonal initialization of every weight-bearing submodule.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d,
                                  nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(layer.weight.data)

    def forward(self, x):
        padded = F.pad(x, (self.padding_left, self.padding_right))
        return self.conv(padded)
|
||||
|
||||
#(modified) GLU Activation layer definition
|
||||
class GLU(nn.Module):
    """Modified gated linear unit: tanh(x) gated by sigmoid(W x), where
    only the gate carries learned (weight-normalized, bias-free) weights."""

    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)  # deterministic initialization

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Orthogonal initialization of every weight-bearing submodule.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d,
                                  nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(layer.weight.data)

    def forward(self, x):
        return torch.tanh(x) * torch.sigmoid(self.gate(x))
|
||||
|
||||
#GRU layer definition
|
||||
class ContForwardGRU(nn.Module):
    """Forward GRU whose initial hidden state is derived from 320 history
    audio samples instead of zeros, followed by a GLU."""
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()

        # Fixed seed so the orthogonal init below is reproducible.
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        #This is to initialize the layer with history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        # Orthogonal init of conv/linear/embedding submodules.
        # NOTE: nn.GRU is not in this isinstance list, so the GRU weights
        # keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        """x: (batch, T, input_size); x0: (batch, 320) history samples."""
        self.gru.flatten_parameters()

        # Derive h0 from the history samples; shape (1, batch, hidden).
        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        # Gated nonlinearity on the full output sequence; the final hidden
        # state is discarded.
        return self.nl(output)
|
||||
|
||||
# Framewise convolution layer definition
|
||||
class ContFramewiseConv(torch.nn.Module):
    """Frame-wise "convolution": an FC applied to frame_kernel_size
    consecutive input frames per output frame, where the missing left
    context of the first frame(s) is synthesized from 320 history audio
    samples instead of zero padding."""

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):

        super(ContFramewiseConv, self).__init__()
        # Fixed seed so the orthogonal init below is reproducible.
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            # Causal: all extra context comes from the past.
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            #This is to initialize the layer with history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )

        else:
            #This means non-causal frame-wise convolution. We don't use it at the moment
            # NOTE(review): this branch defines no cont_fc, yet forward()
            # unconditionally calls self.cont_fc and never applies
            # required_pad_right — unsupported path, as the comment says.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        # Output activation: gated (GLU) or plain tanh.
        if act=='glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act=='tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )

        self.init_weights()


    def init_weights(self):
        # Orthogonal init of all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        """x: (batch, frames, frame_len); x0: (batch, 320) history samples."""
        if self.frame_kernel_size == 1:
            # No context needed: plain per-frame FC.
            return self.fc(x)

        # Flatten frames to one long sequence and prepend synthesized context.
        x_flat = x.reshape(x.size(0),1,-1)
        pad = self.cont_fc(x0).view(x0.size(0),1,-1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        # Sliding windows of frame_kernel_size frames, one window per frame.
        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                                          kernel_size= (1,self.fc_input_dim), stride=self.frame_len).permute(0,2,1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out
|
||||
|
||||
########################### The complete model definition #################################
|
||||
|
||||
class FWGAN500Cont(nn.Module):
    """FWGAN vocoder (500-series) with continuation.

    Unlike the 400-series model, the raw 320-sample history vector `x0`
    is passed directly to the continuation inputs of the GRU and every
    frame-wise conv layer (each embeds it with its own Linear(320, ...)).
    Frames are 160 samples, regrouped into 32-sample subframes (5 per frame).
    """
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)  # reproducible initialization

        #PrecondNet: upsample 19-dim features 5x in time to 64 channels.
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19,64,kernel_size=5,stride=5,padding=0,
                                                                         bias=False),
                                                      nn.Tanh())

        # Conv with one frame look-ahead + gated activation.
        self.feat_in_conv = ConvLookahead(128,256,kernel_size=5)
        self.feat_in_nl = GLU(256)

        #GRU:
        self.rnn = ContForwardGRU(256,256)

        #Frame-wise convolution stack (last layer: 32 samples per subframe):
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        # Orthogonal init of all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        # Print the number of trainable parameters (called once at init).
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        """Build per-sample sin/cos phase embeddings from pitch periods.

        periods: (batch, frames). Within each 160-sample frame the phase
        advances by 2*pi/period per sample, accumulated across frames;
        samples are regrouped into 32-sample subframes and sin/cos halves
        concatenated, yielding (batch, frames*5, 64).
        """
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0),-1,32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0),-1,32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim = -1)

            # Carry the accumulated phase into the next frame.
            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals


    def gain_multiply(self, x, c0):
        """Scale the waveform by a per-frame gain derived from c0
        (presumably the first BFCC coefficient / energy — TODO confirm)."""
        gain = 10**(0.5*c0/np.sqrt(18.0))
        # Expand per-frame gain to per-sample (160 samples per frame).
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0),1,-1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        """pitch_period: (batch, frames); bfcc_with_corr: (batch, frames, 19);
        x0: (batch, 320) history samples. Returns (batch, samples) waveform."""
        #This should create a latent representation of shape [Batch_dim, 500 frames, 256 elements per frame]
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0,2,1).contiguous())
        # 64 phase channels + 64 envelope channels -> 128-channel conditioning.
        feat_in = torch.cat((p_embed , envelope), dim=1)
        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0,2,1).contiguous())

        #Generation with continuation using history samples x0 starts from here:

        rnn_out = self.rnn(wav_latent, x0)

        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        # Subframes -> flat waveform, then apply per-frame gain.
        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0),1,-1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled,bfcc_with_corr[:,:,:1])

        return waveform
|
||||
27
managed_components/78__esp-opus/dnn/torch/lossgen/README.md
Normal file
27
managed_components/78__esp-opus/dnn/torch/lossgen/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Packet loss simulator
|
||||
|
||||
This code is an attempt at simulating better packet loss scenarios. The most common way of simulating
|
||||
packet loss is to use a random sequence where each packet loss event is uncorrelated with previous events.
|
||||
That is a simplistic model since we know that losses often occur in bursts. This model uses real data
|
||||
to build a generative model for packet loss.
|
||||
|
||||
We use the training data provided for the Audio Deep Packet Loss Concealment Challenge, which is available at:
|
||||
|
||||
http://plcchallenge2022pub.blob.core.windows.net/plcchallengearchive/test_train.tar.gz
|
||||
|
||||
To create the training data, run:
|
||||
|
||||
`./process_data.sh /<path>/test_train/train/lossy_signals/`
|
||||
|
||||
That will create an ascii loss\_sorted.txt file with all loss data sorted in increasing packet loss
|
||||
percentage. Then just run:
|
||||
|
||||
`python ./train_lossgen.py`
|
||||
|
||||
to train a model
|
||||
|
||||
To generate a sequence, run
|
||||
|
||||
`python3 ./test_lossgen.py <checkpoint> <percentage> output.txt --length 10000`
|
||||
|
||||
where <checkpoint> is the .pth model file and <percentage> is the amount of loss (e.g. 0.2 for 20% loss).
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('checkpoint', type=str, help='model checkpoint')
|
||||
parser.add_argument('output_dir', type=str, help='output folder')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import lossgen
|
||||
from wexchange.torch import dump_torch_weights
|
||||
from wexchange.c_export import CWriter, print_vector
|
||||
|
||||
def c_export(args, model):
    """Export LossGen weights as C source/header via the wexchange writer.

    Dense layers are dumped unquantized; the two GRU layers are quantized,
    and the widest GRU determines LOSSGEN_MAX_RNN_UNITS for buffer sizing.
    """
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False, add_typedef=True)
    writer.header.write(
f"""
#include "opus_types.h"
"""
    )

    # (submodule name, exported C symbol prefix) pairs.
    dense_layers = [
        ('dense_in', "lossgen_dense_in"),
        ('dense_out', "lossgen_dense_out")
    ]

    for name, export_name in dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)

    gru_layers = [
        ("gru1", "lossgen_gru1"),
        ("gru2", "lossgen_gru2"),
    ]

    # dump_torch_weights returns the unit count for GRU layers; keep the max.
    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
                         for name, export_name in gru_layers])

    writer.header.write(
f"""

#define LOSSGEN_MAX_RNN_UNITS {max_rnn_units}

"""
    )

    writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Rebuild the model from the checkpoint's stored constructor args,
    # load its weights, and export everything as C data.
    os.makedirs(args.output_dir, exist_ok=True)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    c_export(args, model)
|
||||
29
managed_components/78__esp-opus/dnn/torch/lossgen/lossgen.py
Normal file
29
managed_components/78__esp-opus/dnn/torch/lossgen/lossgen.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LossGen(nn.Module):
    """Two-layer GRU generative model of packet-loss sequences.

    Given the previous loss indicator and a running loss percentage, it
    produces a logit for the next loss event together with the updated
    recurrent states.
    """

    def __init__(self, gru1_size=16, gru2_size=16):
        super(LossGen, self).__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """loss, perc: (batch, T, 1). Returns (logits, [h1, h2])."""
        if states is None:
            # Fresh zero states matching the batch and device of the input.
            n_batch = loss.size(0)
            dev = loss.device
            h1 = torch.zeros((1, n_batch, self.gru1_size), device=dev)
            h2 = torch.zeros((1, n_batch, self.gru2_size), device=dev)
        else:
            h1 = states[0]
            h2 = states[1]
        features = torch.cat([loss, perc], dim=-1)
        hidden = torch.tanh(self.dense_in(features))
        out1, h1 = self.gru1(hidden, h1)
        out2, h2 = self.gru2(out1, h2)
        return self.dense_out(out2), [h1, h2]
|
||||
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh

#directory containing the loss files
datadir=$1

# For every per-file loss trace, compute its average loss rate
# (mean of the 0/1 indicators) and pair it with the file name.
for i in $datadir/*_is_lost.txt
do
perc=`cat $i | awk '{a+=$1}END{print a/NR}'`
echo $perc $i
done > percentage_list.txt

# Sort file names by increasing loss percentage.
sort -n percentage_list.txt | awk '{print $2}' > percentage_sorted.txt

# Concatenate all traces in that order into one training file.
for i in `cat percentage_sorted.txt`
do
cat $i
done > loss_sorted.txt
|
||||
@@ -0,0 +1,42 @@
|
||||
import lossgen
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('percentage', type=float, help='percentage loss')
parser.add_argument('output', type=str, help='path to output file (ascii)')

parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)

args = parser.parse_args()


# Rebuild the model from the checkpoint's stored constructor args.
checkpoint = torch.load(args.model, map_location='cpu')
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)

# Autoregressive sampling state: previous loss indicator, target loss
# percentage (constant), and the accumulated generated sequence.
states=None
last = torch.zeros((1,1,1))
perc = torch.tensor((args.percentage,))[None,None,:]
seq = torch.zeros((0,1,1))

one = torch.ones((1,1,1))
zero = torch.zeros((1,1,1))

if __name__ == '__main__':
    # Sample one loss event at a time from the model's Bernoulli output.
    for i in range(args.length):
        prob, states = model(last, perc, states=states)
        prob = torch.sigmoid(prob)
        # Detach states so the autograd graph does not grow across steps.
        states[0] = states[0].detach()
        states[1] = states[1].detach()
        loss = one if np.random.rand() < prob else zero
        last = loss
        seq = torch.cat([seq, loss])

    # Write the generated 0/1 sequence, one value per line.
    np.savetxt(args.output, seq[:,:,0].numpy().astype('int'), fmt='%d')
|
||||
@@ -0,0 +1,99 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
from scipy.signal import lfilter
|
||||
import os
|
||||
import lossgen
|
||||
|
||||
class LossDataset(torch.utils.data.Dataset):
    """Dataset of packet-loss sequences plus a smoothed running loss rate.

    Reads an ascii file of 0/1 loss indicators, splits it into fixed-length
    sequences, and derives a per-sample loss percentage via a slow one-pole
    IIR filter computed over the whole file.
    """
    def __init__(self,
                loss_file,
                sequence_length=997):

        self.sequence_length = sequence_length

        self.loss = np.loadtxt(loss_file, dtype='float32')

        # Truncate to a whole number of sequences.
        self.nb_sequences = self.loss.shape[0]//self.sequence_length
        self.loss = self.loss[:self.nb_sequences*self.sequence_length]
        # Running loss percentage: one-pole smoother (pole at 0.999,
        # ~1000-sample memory) over the raw indicators.
        self.perc = lfilter(np.array([.001], dtype='float32'), np.array([1., -.999], dtype='float32'), self.loss)

        self.loss = np.reshape(self.loss, (self.nb_sequences, self.sequence_length, 1))
        self.perc = np.reshape(self.perc, (self.nb_sequences, self.sequence_length, 1))

    def __len__(self):
        # Number of fixed-length sequences.
        return self.nb_sequences

    def __getitem__(self, index):
        # Jitter the percentage with multiplicative noise that vanishes at
        # 0 and 1 (perc*(1-perc) factor); r0 is shared across the sequence,
        # r1 varies per sample.
        r0 = np.random.normal(scale=.1, size=(1,1)).astype('float32')
        r1 = np.random.normal(scale=.1, size=(self.sequence_length,1)).astype('float32')
        perc = self.perc[index, :, :]
        perc = perc + (r0+r1)*perc*(1-perc)
        return [self.loss[index, :, :], perc]
|
||||
|
||||
|
||||
# Training hyper-parameters.
adam_betas = [0.8, 0.98]
adam_eps = 1e-8
batch_size=256
lr_decay = 0.001
lr = 0.003
epsilon = 1e-5  # numerical floor inside the log of the BCE loss
epochs = 2000
checkpoint_dir='checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = dict()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Store constructor args in the checkpoint so inference/export scripts
# can rebuild the same model.
checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'gru1_size': 16, 'gru2_size': 32}
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
dataset = LossDataset('loss_sorted.txt')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler: lr / (1 + lr_decay * step)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))


if __name__ == '__main__':
    model.to(device)
    # NOTE(review): GRU states are carried across batches (and epochs) even
    # though the dataloader shuffles sequences — confirm this is intended.
    states = None
    for epoch in range(1, epochs + 1):

        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (loss, perc) in enumerate(tepoch):
                optimizer.zero_grad()
                loss = loss.to(device)
                perc = perc.to(device)

                out, states = model(loss, perc, states=states)
                # Detach so the autograd graph does not span batches.
                states = [state.detach() for state in states]
                # Predict sample t+1 from samples up to t.
                out = torch.sigmoid(out[:,:-1,:])
                target = loss[:,1:,:]

                # Binary cross-entropy with an epsilon floor (note: `loss`
                # is rebound here from the data tensor to the scalar loss).
                loss = torch.mean(-target*torch.log(out+epsilon) - (1-target)*torch.log(1-out+epsilon))

                loss.backward()
                optimizer.step()

                scheduler.step()

                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
|
||||
27
managed_components/78__esp-opus/dnn/torch/lpcnet/README.md
Normal file
27
managed_components/78__esp-opus/dnn/torch/lpcnet/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# LPCNet
|
||||
|
||||
Incomplete pytorch implementation of LPCNet
|
||||
|
||||
## Data preparation
|
||||
For data preparation use dump_data in github.com/xiph/LPCNet. To turn this into
|
||||
a training dataset, copy data and feature file to a folder and run
|
||||
|
||||
python add_dataset_config.py my_dataset_folder
|
||||
|
||||
|
||||
## Training
|
||||
To train a model, create and adjust a setup file, e.g. with
|
||||
|
||||
python make_default_setup.py my_setup.yml --path2dataset my_dataset_folder
|
||||
|
||||
Then simply run
|
||||
|
||||
python train_lpcnet.py my_setup.yml my_output
|
||||
|
||||
## Inference
|
||||
Create feature file with dump_data from github.com/xiph/LPCNet. Then run e.g.
|
||||
|
||||
python test_lpcnet.py features.f32 my_output/checkpoints/checkpoint_ep_10.pth out.wav
|
||||
|
||||
Inference runs on CPU and takes usually between 3 and 20 seconds per generated second of audio,
|
||||
depending on the CPU.
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
import os

import yaml


from utils.templates import dataset_template_v1, dataset_template_v2


# Command-line tool: scan a dataset folder for exactly one feature file (.f32)
# and one data file (.u8 for v1 / .s16 for v2), fill a dataset template with
# their names, and write the result to <path>/info.yml.

parser = argparse.ArgumentParser("add_dataset_config.py")

parser.add_argument('path', type=str, help='path to folder containing feature and data file')
parser.add_argument('--version', type=int, help="dataset version, 1 for classic LPCNet with 55 feature slots, 2 for new format with 36 feature slots.", default=2)
parser.add_argument('--description', type=str, help='brief dataset description', default="I will add a description later")
args = parser.parse_args()


# pick template and raw-data extension by dataset version
if args.version == 1:
    template = dataset_template_v1
    data_extension = '.u8'
elif args.version == 2:
    template = dataset_template_v2
    data_extension = '.s16'
else:
    raise ValueError(f"unknown dataset version {args.version}")

# get folder content
content = os.listdir(args.path)

features = [c for c in content if c.endswith('.f32')]

# NOTE(review): detection failures only print a warning; info.yml is still
# written with whatever the template contains. Confirm this best-effort
# behavior is intended.
if len(features) != 1:
    print("could not determine feature file")
else:
    template['feature_file'] = features[0]

data = [c for c in content if c.endswith(data_extension)]
if len(data) != 1:
    print("could not determine data file")
else:
    template['signal_file'] = data[0]

template['description'] = args.description

# write the completed dataset descriptor next to the data
with open(os.path.join(args.path, 'info.yml'), 'w') as f:
    yaml.dump(template, f)
|
||||
@@ -0,0 +1 @@
|
||||
from .lpcnet_dataset import LPCNetDataset
|
||||
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" Dataset for LPCNet training """
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# Companding factors between 16-bit linear PCM and the 8-bit mu-law domain.
scale = 255.0/32768.0      # linear -> mu-law compression factor
scale_1 = 32768.0/255.0    # mu-law -> linear expansion factor


def ulaw2lin(u):
    """Expand 8-bit mu-law codes (0..255, with 128 = zero) to linear amplitudes."""
    centered = u - 128
    sign = np.sign(centered)
    magnitude = np.abs(centered)
    expanded = np.exp(magnitude / 128. * np.log(256)) - 1
    return sign * scale_1 * expanded


def lin2ulaw(x):
    """Compress linear amplitudes to 8-bit mu-law codes clipped to [0, 255]."""
    sign = np.sign(x)
    magnitude = np.abs(x)
    code = sign * (128 * np.log(1 + scale * magnitude) / np.log(256))
    return np.clip(128 + np.round(code), 0, 255)
|
||||
|
||||
|
||||
def run_lpc(signal, lpcs, frame_length=160):
    """Run frame-wise LPC analysis over `signal`.

    For frame i the predictor uses `lpcs[i]` (one coefficient row per frame)
    on a window that includes lpc_order - 1 history samples. `signal` must
    therefore hold num_frames * frame_length + lpc_order samples.

    Returns:
        (prediction, error): prediction is the negated LPC prediction per
        sample; error is signal[lpc_order:] minus that prediction.
    """
    num_frames, lpc_order = lpcs.shape

    per_frame = []
    for frame_index, coeffs in enumerate(lpcs):
        start = frame_index * frame_length
        stop = start + frame_length + lpc_order - 1
        per_frame.append(-np.convolve(signal[start : stop], coeffs, mode='valid'))

    prediction = np.concatenate(per_frame)
    error = signal[lpc_order :] - prediction

    return prediction, error
|
||||
|
||||
class LPCNetDataset(Dataset):
    """Memory-mapped dataset of frame-rate features and sample-rate signals for LPCNet training.

    Reads <path>/info.yml for the layout of a feature file (.f32) and a signal
    file, memory-maps both, and serves samples of `frames_per_sample` frames
    plus feature history/lookahead context. Two on-disk layouts are supported
    (version 1 and 2); the matching getitem implementation is selected in
    __init__.
    """
    # NOTE(review): the list defaults below are mutable default arguments.
    # They appear to be only read (never mutated) in this class, but consider
    # None-defaults to be safe.
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 input_signals=['last_signal', 'prediction', 'last_error'],
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):

        super(LPCNetDataset, self).__init__()

        # load dataset info (layouts, dtypes, file names) from info.yml
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the item-assembly strategy
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features: history/lookahead frames are prepended/appended to every sample
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        # first usable frame: one guard frame plus the required history
        self.frame_offset = 1 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        # maps feature name -> (start, stop) column range in a feature frame
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file as a flat memmap, then reshape to (frames, slots)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples is dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        # maps signal name -> column index in a signal frame
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.input_signals = input_signals
        self.target = target

        # load signals; one row per audio sample, frame_length rows per feature frame
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        # dispatch to the version-specific implementation chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Assemble one training sample for the version-2 layout.

        Signals are stored in the linear domain and converted to mu-law here;
        prediction/error are recomputed from LPC coefficients when the signal
        file does not provide them.
        """
        sample = dict()

        # extract features (with history before and lookahead after the sample)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods from their float encoding to integer embedding indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting (bandwidth expansion by gamma^(i+1) per coefficient)
            # NOTE(review): lpc_order is recomputed here; identical to the line above.
            lpc_order = lpc_stop - lpc_start
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            # NOTE(review): clean_signal is computed but never used below — confirm
            # whether the error was meant to be derived from it.
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the extra leading frame used for history)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
            # calculate error between real signal and noisy prediction

            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features (periods are kept separate for the embedding)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # v2 stores linear-domain signals: convert to mu-law indices here
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def getitem_v1(self, index):
        """Assemble one training sample for the classic (version-1) layout.

        Signals are assumed to already be stored as mu-law indices, so no
        lin2ulaw conversion is applied here.
        """
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # NOTE(review): the loop variable shadows the `index` parameter; harmless
        # here because the sample index is not used after this point.
        for signal_name, index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        # number of complete samples derived from the feature-file length
        return self.dataset_length
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Train `model` for one epoch over `dataloader` and return the mean batch loss.

    The GRU hidden-state tensors are allocated once and zeroed in place at the
    start of every batch, so `gru_states` keeps referencing the same storage
    and no state leaks between batches.

    Args:
        model: LPCNet-style module exposing `gru_a_units`, `gru_b_units` and
            a `sparsify()` method.
        criterion: loss called as criterion(output.permute(0, 2, 1), target),
            i.e. it expects the class dimension second.
        optimizer: optimizer over `model`'s parameters.
        dataloader: yields dicts with 'features', 'periods', 'signals', 'target'.
        device: device to train on.
        scheduler: learning-rate scheduler, stepped once per batch.
        log_interval: number of batches between progress-bar updates.

    Returns:
        Average loss per batch over the epoch (float).
    """

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    # GRU states, created directly on the target device (the former trailing
    # `.to(device)` was redundant given `device=device`).
    # NOTE(review): the batch dimension is dataloader.batch_size — assumes full
    # (or dropped-last) batches; confirm the loader's drop_last setting.
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # zero out initial gru states (in place, so gru_states stays valid)
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss (move class dimension to position 1 for the criterion)
            loss = criterion(output.permute(0, 2, 1), target)

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # call sparsifier (must run after the weight update)
            model.sparsify()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
|
||||
|
||||
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate `model` over `dataloader` without gradients; return the mean batch loss.

    Mirrors train_one_epoch's batch handling (state zeroing, device transfer,
    permute for the criterion) but performs no optimizer/scheduler/sparsifier
    steps and runs under torch.no_grad().

    Args:
        model: LPCNet-style module exposing `gru_a_units` and `gru_b_units`.
        criterion: loss called as criterion(output.permute(0, 2, 1), target).
        dataloader: yields dicts with 'features', 'periods', 'signals', 'target'.
        device: device to evaluate on.
        log_interval: number of batches between progress-bar updates.

    Returns:
        Average loss per batch (float).
    """

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    # GRU states, created directly on the target device (the former trailing
    # `.to(device)` was redundant given `device=device`).
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with torch.no_grad():

        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # zero out initial gru states
                gru_a_state.zero_()
                gru_b_state.zero_()

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

                # calculate loss
                loss = criterion(output.permute(0, 2, 1), target)

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss


        running_loss /= len(dataloader)

        return running_loss
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse

import yaml

from utils.templates import setup_dict

# Command-line tool: dump the default training setup for a model to a YAML file.

parser = argparse.ArgumentParser()

parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lpcnet', 'multi_rate'], help='LPCNet model name', default='lpcnet')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

args = parser.parse_args()

# look up the default setup template for the requested model
setup = setup_dict[args.model]

# update dataset if given (idiomatic None check instead of type(...) != type(None))
if args.path2dataset is not None:
    setup['dataset'] = args.path2dataset

# make sure the output file carries a .yml extension
name = args.name
if not name.endswith('.yml'):
    name += '.yml'

if __name__ == '__main__':
    with open(name, 'w') as f:
        f.write(yaml.dump(setup))
|
||||
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
import os
import sys

# Command-line tool: render a YAML test-pipeline config for evaluating an
# LPCNet checkpoint, print it, and write it to <config_name>.yml.

parser = argparse.ArgumentParser()
parser.add_argument("config_name", type=str, help="name of config file (.yml will be appended)")
parser.add_argument("test_name", type=str, help="name for test result display")
parser.add_argument("checkpoint", type=str, help="checkpoint to test")
parser.add_argument("--lpcnet-demo", type=str, help="path to lpcnet_demo binary, default: /local/code/LPCNet/lpcnet_demo", default="/local/code/LPCNet/lpcnet_demo")
parser.add_argument("--lpcnext-path", type=str, help="path to lpcnext folder, default: dirname(__file__)", default=os.path.dirname(__file__))
parser.add_argument("--python-exe", type=str, help='python executable path, default: sys.executable', default=sys.executable)
parser.add_argument("--pad", type=str, help="left pad of output in seconds, default: 0.015", default="0.015")
parser.add_argument("--trim", type=str, help="left trim of output in seconds, default: 0", default="0")


# YAML pipeline template. Single-brace fields are filled in below; the
# double-brace {{INPUT}}/{{OUTPUT}} placeholders survive str.format and are
# substituted later by the test runner.
template='''
test: "{NAME}"
processing:
- "sox {{INPUT}} {{INPUT}}.raw"
- "{LPCNET_DEMO} -features {{INPUT}}.raw {{INPUT}}.features.f32"
- "{PYTHON} {WORKING}/test_lpcnet.py {{INPUT}}.features.f32 {CHECKPOINT} {{OUTPUT}}.ua.wav"
- "sox {{OUTPUT}}.ua.wav {{OUTPUT}}.uap.wav pad {PAD}"
- "sox {{OUTPUT}}.uap.wav {{OUTPUT}} trim {TRIM}"
- "rm {{INPUT}}.raw {{OUTPUT}}.uap.wav {{OUTPUT}}.ua.wav {{INPUT}}.features.f32"
'''

if __name__ == "__main__":
    args = parser.parse_args()

    # resolve every path argument to an absolute path so the generated
    # config works regardless of the runner's working directory
    file_content = template.format(
        NAME=args.test_name,
        LPCNET_DEMO=os.path.abspath(args.lpcnet_demo),
        PYTHON=os.path.abspath(args.python_exe),
        PAD=args.pad,
        TRIM=args.trim,
        WORKING=os.path.abspath(args.lpcnext_path),
        CHECKPOINT=os.path.abspath(args.checkpoint)
    )

    print(file_content)

    filename = args.config_name
    if not filename.endswith(".yml"):
        filename += ".yml"

    with open(filename, "w") as f:
        f.write(file_content)
|
||||
@@ -0,0 +1,8 @@
|
||||
from .lpcnet import LPCNet
from .multi_rate_lpcnet import MultiRateLPCNet


# Registry mapping the model name used in setup files (e.g. the --model
# option of make_default_setup.py) to its implementing class.
model_dict = {
    'lpcnet' : LPCNet,
    'multi_rate' : MultiRateLPCNet
}
|
||||
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from utils.ulaw import lin2ulawq, ulaw2lin
|
||||
from utils.sample import sample_excitation
|
||||
from utils.pcm import clip_to_int16
|
||||
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
|
||||
from utils.layers import DualFC
|
||||
from utils.misc import get_pdf_from_tree
|
||||
|
||||
|
||||
class LPCNet(nn.Module):
|
||||
    def __init__(self, config):
        """Build LPCNet from a `config` dict (a setup-file 'model' section).

        Constructs the frame rate network (period embedding + two 'valid'
        convolutions + two dense layers), the sample rate network (signal
        embedding + GRU A + GRU B + DualFC output), and optional GRU
        sparsifiers, and precomputes per-step FLOP counts for get_gflops().
        """
        super(LPCNet, self).__init__()

        # input layout and feature context taken from the setup file
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        # conv input channels: raw features plus the embedded period
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']


        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        # square Linear: conditioning_dim -> conditioning_dim
        self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']
        # hierarchical sampling: DualFC output is interpreted as a binary
        # probability tree instead of a flat softmax (see sample_rate_network)
        self.hsampling = config.get('hsampling', False)

        # GRU A consumes all embedded input signals plus the conditioning vector
        self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc = DualFC(self.gru_b_units, self.output_levels)

        # sparsification: one GRUSparsifier per configured GRU, stepped by sparsify()
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            # drop_input=True: GRU A's input contribution is excluded from the
            # FLOP count (handled via embedding lookups in the C inference code
            # -- presumably; confirm)
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # inference parameters: LPC bandwidth-expansion factor
        self.lpc_gamma = config.get('lpc_gamma', 1)
|
||||
|
||||
def sparsify(self):
|
||||
for sparsifier in self.sparsifier:
|
||||
sparsifier.step()
|
||||
|
||||
    def get_gflops(self, fs, verbose=False):
        """Estimate inference complexity in GFLOPS at sampling rate `fs` (Hz).

        Sums analytic FLOP estimates for the frame rate network (run once per
        frame), GRU A, GRU B, and the DualFC output layer (run once per
        sample). With hierarchical sampling the DualFC cost is divided by 8.

        Args:
            fs: audio sampling rate in Hz.
            verbose: if True, print the per-component breakdown.

        Returns:
            Estimated total complexity in GFLOPS (float).
        """
        gflops = 0

        # frame rate network: runs once per frame (fs / frame_size)
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a: runs once per audio sample
        gru_a_rate = fs
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # gru b: runs once per audio sample
        gru_b_rate = fs
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity


        # dual fcs: runs once per audio sample
        fc = self.dual_fc
        rate = fs
        input_size = fc.dense1.in_features
        output_size = fc.dense1.out_features
        dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
        # hierarchical sampling evaluates only part of the output tree
        if self.hsampling:
            dual_fc_complexity /= 8
        if verbose:
            print(f"dual_fc: {dual_fc_complexity} GFLOPS")
        gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops
|
||||
|
||||
def frame_rate_network(self, features, periods):
|
||||
|
||||
embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
|
||||
features = torch.concat((features, embedded_periods), dim=-1)
|
||||
|
||||
# convert to channels first and calculate conditioning vector
|
||||
c = torch.permute(features, [0, 2, 1])
|
||||
|
||||
c = torch.tanh(self.feature_conv1(c))
|
||||
c = torch.tanh(self.feature_conv2(c))
|
||||
# back to channels last
|
||||
c = torch.permute(c, [0, 2, 1])
|
||||
c = torch.tanh(self.feature_dense1(c))
|
||||
c = torch.tanh(self.feature_dense2(c))
|
||||
|
||||
return c
|
||||
|
||||
    def sample_rate_network(self, signals, c, gru_states):
        """Run the sample-rate part of the network over a full sequence.

        Args:
            signals: integer signal indices, shape (batch, T, num_signals).
            c: frame-rate conditioning from frame_rate_network, (batch, T // frame_size, cond_dim).
            gru_states: pair of initial hidden states for (gru_a, gru_b).

        Returns:
            (log_probs, (gru_a_state, gru_b_state)) where log_probs are
            per-sample log output probabilities.
        """
        # embed signals and flatten embedding dims: (B, T, S, E) -> (B, T, S*E)
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        # repeat frame-rate conditioning up to sample rate
        c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)

        y = torch.concat((embedded_signals, c_upsampled), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # conditioning is re-injected before GRU B
        y = torch.concat((y, c_upsampled), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            # hierarchical sampling: dual_fc outputs tree-node probabilities;
            # 1e-6 guards against log(0)
            y = torch.sigmoid(y)
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, (gru_a_state, gru_b_state)
def decoder(self, signals, c, gru_states):
|
||||
embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
|
||||
|
||||
y = torch.concat((embedded_signals, c), dim=-1)
|
||||
y, gru_a_state = self.gru_a(y, gru_states[0])
|
||||
y = torch.concat((y, c), dim=-1)
|
||||
y, gru_b_state = self.gru_b(y, gru_states[1])
|
||||
|
||||
y = self.dual_fc(y)
|
||||
|
||||
if self.hsampling:
|
||||
y = torch.sigmoid(y)
|
||||
probs = get_pdf_from_tree(y)
|
||||
else:
|
||||
probs = torch.softmax(y, dim=-1)
|
||||
|
||||
return probs, (gru_a_state, gru_b_state)
|
||||
|
||||
def forward(self, features, periods, signals, gru_states):
|
||||
|
||||
c = self.frame_rate_network(features, periods)
|
||||
log_probs, _ = self.sample_rate_network(signals, c, gru_states)
|
||||
|
||||
return log_probs
|
||||
|
||||
    def generate(self, features, periods, lpcs):
        """Autoregressively generate 16-bit PCM from features, periods and LPCs.

        One sample is produced per decoder step: the LPC prediction is
        computed from past generated samples, the decoder yields excitation
        probabilities, an excitation is sampled, and the new sample is fed
        back into the input signal buffer.

        Returns:
            int16 tensor of length num_frames * frame_size.
        """
        with torch.no_grad():
            device = self.parameters().__next__().device

            # frames usable after trimming history/lookahead context
            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.input_layout['signals'])
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers; pcm keeps lpc_order samples of left context
            # NOTE(review): buffers are allocated on CPU — assumes generation
            # runs on CPU (or relies on implicit transfers); confirm for GPU use
            pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0  # de-emphasis filter memory

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            gru_states = [gru_a_state, gru_b_state]

            # 128 is the ulaw code for zero
            input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # bandwidth-expansion weighting of the LPCs by lpc_gamma
            weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
            lpcs = lpcs * weights

            # run feature encoding once for the whole utterance
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                # flip so the dot product below aligns oldest..newest samples
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # LPC prediction from the last lpc_order generated samples
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
                    if 'prediction' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(
                        input_signals,
                        current_c,
                        gru_states
                    )

                    # sample excitation (pitch_corr controls sampling sharpness)
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation: excitation + prediction
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    # first-order de-emphasis filter before int16 conversion
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # feed generated sample / excitation back into the input
                    if 'last_signal' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)

                    if 'last_error' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)

            return output
@@ -0,0 +1,437 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from utils.layers.subconditioner import get_subconditioner
|
||||
from utils.layers import DualFC
|
||||
|
||||
from utils.ulaw import lin2ulawq, ulaw2lin
|
||||
from utils.sample import sample_excitation
|
||||
from utils.pcm import clip_to_int16
|
||||
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
|
||||
|
||||
from utils.misc import interleave_tensors
|
||||
|
||||
|
||||
|
||||
|
||||
# MultiRateLPCNet
|
||||
class MultiRateLPCNet(nn.Module):
    """LPCNet variant in which GRU A and GRU B run below sample rate.

    GRU A consumes groups of substeps_a * substeps_b samples and GRU B groups
    of substeps_b samples; the missing temporal resolution is recovered by two
    subconditioners that inject freshly generated signal values between GRU
    steps, and by one DualFC output layer per subconditioning-B step.

    The constructor takes a nested config dict (see the setup yaml files).
    """

    def __init__(self, config):
        super(MultiRateLPCNet, self).__init__()

        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']
        self.signals = config['signals']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']

        # subconditioning B (between GRU B and the dual FCs)
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_b_units
        # bug fix: pass the normalized kwargs; sub_config['kwargs'] may be None
        # (yaml `kwargs: null`), in which case **sub_config['kwargs'] raised a
        # TypeError despite the None check above
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **kwargs)

        # subconditioning A (between GRU A and GRU B); it sees substeps_b
        # samples of each of its signals per step
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_a_units
        # bug fix: same None-kwargs issue as for subconditioner_b above
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **kwargs)

        # wrap up subconditioning: group_size_gru_a holds the number of
        # timesteps that are grouped as sample input for GRU A and
        # group_size_subcondition_a holds the number of samples that are
        # grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b

        # gru sizes
        self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification (optional per GRU); flops-per-step are recorded for
        # the complexity estimate in get_gflops
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                    gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                    gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # dual FCs: one output layer per subconditioning-B step; kept in a
        # plain list but registered via add_module so parameters are tracked
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        """Estimate model complexity in GFLOPS at sampling rate ``fs``.

        Components running below sample rate are scaled by their rate
        dividers. ``hierarchical_sampling`` divides the dual-FC cost by 8
        to model tree-based output sampling.
        """
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a, running at fs / group_size_gru_a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a, running at fs / substeps_b
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b, running at fs / substeps_b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b, running at sample rate
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs: each serves 1/len(dual_fc) of the samples
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def sparsify(self):
        """Advance every registered sparsifier by one step."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):
        """Compute the frame-rate conditioning vector (see LPCNet)."""
        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """Extract, delay and group signals.

        Selects the columns in signal_idx, rolls the time axis by
        group_size - 1 so each group only contains past values, and reshapes
        to (batch, T // group_size, group_size * len(signal_idx)).
        """
        batch_size, sequence_length, num_signals = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # roll back pcm to account for grouping
        signals = torch.roll(signals, group_size - 1, -2)

        # reshape to grouped layout
        signals = torch.reshape(signals,
                                (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals

    def sample_rate_network(self, signals, c, gru_states):
        """Training-time sample-rate pass over a full sequence.

        Returns raw dual-FC outputs interleaved back to sample rate plus the
        final GRU states; forward() applies log_softmax on top.
        """
        signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        # one dual FC per subconditioning-B step, interleaved to sample rate
        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        """Single-step decoding without subconditioning.

        NOTE(review): self.dual_fc is a plain list here, so self.dual_fc(y)
        raises TypeError — this path appears dead/broken; generate() below
        implements its own per-substep decoding instead. Left unchanged
        pending clarification of the intended behavior.
        """
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
        """Full training forward pass returning per-sample log-probabilities."""
        c = self.frame_rate_network(features, periods)
        y, _ = self.sample_rate_network(signals, c, gru_states)
        log_probs = torch.log_softmax(y, dim=-1)

        return log_probs

    def generate(self, features, periods, lpcs):
        """Autoregressively generate 16-bit PCM.

        Runs GRU A once per group_size_gru_a samples, then loops over
        subconditioning-A and -B substeps, producing one sample per inner
        step. Returns an int16 tensor of num_frames * frame_size samples.
        """
        with torch.no_grad():
            device = self.parameters().__next__().device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.signals)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers; one extra slot for the one-ahead updates below
            last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0  # de-emphasis filter memory

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))

            # 128 is the ulaw code for zero
            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
            # conditioning signals for subconditioner a
            c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
            # conditioning signals for subconditioner b
            c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)

            # signal dict for name-based buffer lookup
            signal_dict = {
                'prediction' : prediction,
                'last_error' : last_error,
                'last_signal' : last_signal
            }

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

                    # prepare input: last group_size_gru_a values of each signal
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )

                    # run GRU A
                    embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals = torch.flatten(embed_signals, 2)
                    y = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                # bug fix: stride by the number of subconditioning-A
                                # signals; the buffer is sized
                                # group_size_subcondition_a * len(self.signals_idx_a),
                                # so the previous stride of num_input_signals
                                # indexed out of bounds / scrambled the layout
                                # whenever the two signal lists differ in length
                                c_signals_a[idx + slot * len(self.signals_idx_a)] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC for this substep and sample
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

                            # de-emphasis and 16-bit output
                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

            return output
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
from models import model_dict
|
||||
|
||||
|
||||
# debug switch: hard-code arguments instead of parsing the command line
debug = False
if debug:
    # dummy namespace mimicking argparse output
    args = type('dummy', (object,),
    {
        'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml',
        'hierarchical_sampling' : False
    })()
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default=False)', default=False)

    args = parser.parse_args()

# load the training setup description
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# check model; fall back to the default lpcnet when no model entry is given
if not 'model' in setup['lpcnet']:
    print(f'warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# create model (weights are irrelevant for the complexity estimate)
model = model_dict[model_name](setup['lpcnet']['config'])

# report complexity at 16 kHz, printing per-component numbers
gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling)
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from uuid import UUID
|
||||
from collections import OrderedDict
|
||||
import pickle
|
||||
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
|
||||
# command line interface for collecting multi-run results into a csv/pickle
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input folder containing multi-run output")
parser.add_argument("tag", type=str, help="tag for multi-run experiment")
parser.add_argument("csv", type=str, help="name for output csv")
||||
def is_uuid(val):
    """Return True if *val* parses as a UUID string, else False.

    Used to filter run folders, which are named by uuid.
    """
    try:
        UUID(val)
        return True
    # bug fix: the original bare `except:` swallowed everything, including
    # KeyboardInterrupt/SystemExit; UUID() raises ValueError for malformed
    # strings and TypeError/AttributeError for non-string input
    except (ValueError, TypeError, AttributeError):
        return False
||||
def collect_results(folder):
    """Gather training and testing metrics for a single run folder.

    Expects the layout produced by loop_run.sh:
    <folder>/training/{checkpoints, out_finalize.txt} and
    <folder>/testing/final.
    Returns an OrderedDict of scalar metrics.
    """
    training_folder = os.path.join(folder, 'training')
    testing_folder = os.path.join(folder, 'testing')

    # validation loss taken from the finalize checkpoint
    checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
    validation_loss = checkpoint['validation_loss']

    # eval_warpq: last WARP-Q score parsed from the finalize log
    eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]

    # testing results aggregated over test items
    testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))

    results = OrderedDict()
    results['eval_loss'] = validation_loss
    results['eval_warpq'] = eval_warpq
    # index [0] is presumably the mean over test items — TODO confirm in
    # utils.data.collect_test_stats
    results['pesq_mean'] = testing_results['pesq'][0]
    results['warpq_mean'] = testing_results['warpq'][0]
    results['pitch_error_mean'] = testing_results['pitch_error'][0]
    results['voicing_error_mean'] = testing_results['voicing_error'][0]

    return results
||||
def print_csv(path, results, tag, ranks=None, header=True):
    """Write results (and optional per-metric ranks) to a csv file.

    Args:
        path: output file path (overwritten).
        results: mapping uuid -> mapping metric -> float value.
        tag: experiment tag written into every row.
        ranks: optional mapping uuid -> mapping rank-name -> int.
        header: whether to emit the column-name header line.
    """
    metrics = next(iter(results.values())).keys()
    rank_keys = [] if ranks is None else next(iter(ranks.values())).keys()

    with open(path, 'w') as f:
        if header:
            cells = ["uuid, tag"]
            cells += [f", {metric}" for metric in metrics]
            cells += [f", {rank}" for rank in rank_keys]
            f.write("".join(cells) + "\n")

        for uuid, values in results.items():
            cells = [f"{uuid}, {tag}"]
            cells += [f", {val:10.8f}" for val in values.values()]
            cells += [f", {ranks[uuid][rank]:4d}" for rank in rank_keys]
            f.write("".join(cells) + "\n")
def get_ranks(results):
    """Compute a 1-based rank per run and metric.

    Lower values rank better, except for metrics in ``positive`` (pesq_mean,
    mix) where higher is better. Ties keep insertion order (stable sort).
    Returns an OrderedDict uuid -> OrderedDict('rank_<metric>' -> int).
    """
    metrics = list(next(iter(results.values())).keys())

    # metrics where larger values are better
    positive = {'pesq_mean', 'mix'}

    ranks = OrderedDict((key, OrderedDict()) for key in results)

    for metric in metrics:
        ordering = sorted(results,
                          key=lambda k: results[k][metric],
                          reverse=metric in positive)
        for position, key in enumerate(ordering, start=1):
            ranks[key]['rank_' + metric] = position

    return ranks
def analyse_metrics(results):
    """Print the raw metric matrix (metrics x runs) for quick inspection."""
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']

    # one row per metric, one column per run
    rows = [[entry[metric] for entry in results.values()] for metric in metrics]

    print(np.array(rows))
def add_mix_metric(results):
    """Add a combined 'mix' score to every entry of *results* (in place).

    Each metric is sign-flipped so that larger is better, z-normalized across
    runs, and the per-run mean of the normalized scores becomes the 'mix'
    value. Also prints the covariance matrix of the normalized scores.
    """
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
    # flip signs so that larger always means better (only pesq is positive)
    signs = np.array([-1, 1, -1, -1, -1])

    raw = np.array([[entry[metric] for entry in results.values()] for metric in metrics])
    scores = raw.transpose() * signs

    # z-normalize each metric across runs
    z = (scores - np.mean(scores, axis=0)) / np.std(scores, axis=0)

    print(f"covariance matrix for normalized scores of {metrics}:")
    print(np.cov(z.transpose()))

    combined = np.mean(z, axis=1)

    for key, value in zip(results.keys(), combined):
        results[key]['mix'] = value.item()
if __name__ == "__main__":
    args = parser.parse_args()

    # collect per-run result folders (uuid-named subdirectories of input)
    uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])

    results = OrderedDict()

    for uuid in uuids:
        results[uuid] = collect_results(os.path.join(args.input, uuid))

    # combined score and per-metric rankings
    add_mix_metric(results)

    ranks = get_ranks(results)

    # normalize output name so it always ends in .csv
    csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'

    # bug fix: write to the normalized path `csv`; the original passed
    # args.csv, silently ignoring the extension normalization above
    print_csv(csv, results, args.tag, ranks=ranks)

    # also persist the full results dict next to the csv
    with open(csv[:-4] + '.pickle', 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash

# loop_run.sh: train, finalize and evaluate LPCNet models in a loop.
# Each round trains a fresh model under a new uuid, finalizes it, and runs
# the test suite on both the pre-final and the finalized checkpoint.

case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; DEVICE=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    # fix: exit with non-zero status on usage error (was implicit exit 0)
    *) echo "loop_run.sh setup outdir name device rounds lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit 1;;
esac


PYTHON="/home/ubuntu/opt/miniconda3/envs/torch/bin/python"
TESTFEATURES="${LPCNEXT}/testitems/features/all_0_orig_features.f32"
WARPQREFERENCE="${LPCNEXT}/testitems/wav/all_0_orig.wav"
METRICS="warpq,pesq,pitch_error,voicing_error"
LPCNETDEMO="${LPCNET}/lpcnet_demo"

for ((round = 1; round <= ROUNDS; round++))
do
    echo
    echo "round $round"

    UUID=$(uuidgen)
    TRAINOUT="${OUTDIR}/${UUID}/training"
    TESTOUT="${OUTDIR}/${UUID}/testing"
    CHECKPOINT="${TRAINOUT}/checkpoints/checkpoint_last.pth"
    FINALCHECKPOINT="${TRAINOUT}/checkpoints/checkpoint_finalize_last.pth"

    # run training (variables quoted so paths with spaces survive word splitting)
    echo "starting training..."
    "$PYTHON" "$LPCNEXT/train_lpcnet.py" "$SETUP" "$TRAINOUT" --device "$DEVICE" --test-features "$TESTFEATURES" --warpq-reference "$WARPQREFERENCE"

    # run finalization
    echo "starting finalization..."
    "$PYTHON" "$LPCNEXT/train_lpcnet.py" "$SETUP" "$TRAINOUT" \
        --device "$DEVICE" --test-features "$TESTFEATURES" \
        --warpq-reference "$WARPQREFERENCE" \
        --finalize --initial-checkpoint "$CHECKPOINT"

    # create test configs
    "$PYTHON" "$LPCNEXT/make_test_config.py" "${OUTDIR}/${UUID}/testconfig.yml" "$NAME $UUID" "$CHECKPOINT" --lpcnet-demo "$LPCNETDEMO"
    "$PYTHON" "$LPCNEXT/make_test_config.py" "${OUTDIR}/${UUID}/testconfig_finalize.yml" "$NAME $UUID finalized" "$FINALCHECKPOINT" --lpcnet-demo "$LPCNETDEMO"

    # run tests
    echo "starting test 1 (no finalization)..."
    "$PYTHON" "$TESTSUITE/run_test.py" "${OUTDIR}/${UUID}/testconfig.yml" \
        "$TESTITEMS" "${TESTOUT}/prefinal" --num-workers 8 \
        --num-testitems 400 --metrics "$METRICS"

    echo "starting test 2 (after finalization)..."
    "$PYTHON" "$TESTSUITE/run_test.py" "${OUTDIR}/${UUID}/testconfig_finalize.yml" \
        "$TESTITEMS" "${TESTOUT}/final" --num-workers 8 \
        --num-testitems 400 --metrics "$METRICS"
done
|
||||
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
""" script for creating animations from debug data
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
import sys
|
||||
sys.path.append('./')
|
||||
|
||||
from utils.endoscopy import make_animation, read_data
|
||||
|
||||
parser = argparse.ArgumentParser()

parser.add_argument('folder', type=str, help='endoscopy folder with debug output')
parser.add_argument('output', type=str, help='output file (will be auto-extended with .mp4)')

parser.add_argument('--start-index', type=int, help='index of first sample to be considered', default=0)
parser.add_argument('--stop-index', type=int, help='index of last sample to be considered', default=-1)
parser.add_argument('--interval', type=int, help='interval between frames in ms', default=20)
parser.add_argument('--half-window-length', type=int, help='half size of window for displaying signals', default=80)


if __name__ == "__main__":
    args = parser.parse_args()

    # make sure the output name carries the expected extension
    target = args.output if args.output.endswith('.mp4') else args.output + '.mp4'

    debug_data = read_data(args.folder)

    make_animation(
        debug_data,
        target,
        start_index=args.start_index,
        stop_index=args.stop_index,
        half_signal_window_length=args.half_window_length
    )
|
||||
@@ -0,0 +1,17 @@
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="sets s_t to augmented_s_t")

parser.add_argument('datafile', type=str, help='data.s16 file path')

args = parser.parse_args()

# open the interleaved 16-bit data file for in-place modification
data = np.memmap(args.datafile, dtype='int16', mode='readwrite')

# layout: the signal s_t lives in data[1::2], the last augmented signal
# in data[0::2]; overwrite s_t with the augmented signal of the next slot
data[1 : -1 : 2] = data[2 : : 2]
|
||||
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash

# multi_run.sh: launch one detached loop_run.sh job queue per CUDA device.

case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; NUMDEVICES=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    # fix: exit with non-zero status on usage error (was implicit exit 0)
    *) echo "multi_run.sh setup outdir name num_devices rounds_per_device lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit 1;;
esac


LOOPRUN="${LPCNEXT}/loop_run.sh"

mkdir -p "$OUTDIR"

for ((i = 0; i < NUMDEVICES; i++))
do
    echo "launching job queue for device $i"
    # one background queue per device; stdout/stderr go to a per-device log
    nohup bash "$LOOPRUN" "$SETUP" "$OUTDIR" "$NAME" "cuda:$i" "$ROUNDS" "$LPCNEXT" "$LPCNET" "$TESTSUITE" "$TESTITEMS" > "$OUTDIR/job_${i}_out.txt" &
done
|
||||
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash

# run_inference_test.sh: synthesize test audio with every checkpoint found
# in an output folder.

case $# in
    3) FEATURES=$1; FOLDER=$2; PYTHON=$3;;
    # fix: exit with non-zero status on usage error (was implicit exit 0)
    *) echo "run_inference_test.sh <features file> <output folder> <python path>"; exit 1;;
esac


SCRIPTFOLDER=$(dirname "$0")

mkdir -p "$FOLDER/inference_test"

# run inference for every stored checkpoint
# (fixed misleading copy-pasted comment: this runs inference, not an update)
for fn in $(find "$FOLDER" -type f -name "checkpoint*.pth")
do
    tmp=$(basename "$fn")
    tmp=${tmp%.pth}
    epoch=${tmp#checkpoint_epoch_}
    echo "running inference with checkpoint $fn..."
    "$PYTHON" "$SCRIPTFOLDER/../test_lpcnet.py" "$FEATURES" "$fn" "$FOLDER/inference_test/output_epoch_${epoch}.wav"
done
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" script for updating checkpoints with new setup entries
|
||||
|
||||
Use this script to update older outputs with newly introduced
|
||||
parameters. (Saves us the trouble of backward compatibility)
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint_file', type=str, help='checkpoint to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)

args = parser.parse_args()

# load on CPU so no GPU is required for the update
checkpoint = torch.load(args.checkpoint_file, map_location='cpu')

# update model entry (idiom fix: compare against None with 'is not None'
# instead of 'type(...) != type(None)')
if args.model is not None:
    checkpoint['setup']['lpcnet']['model'] = args.model

# write the (possibly modified) checkpoint back in place
torch.save(checkpoint, args.checkpoint_file)
|
||||
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash

# update_output_folder.sh: refresh the setup file and all checkpoints in an
# output folder with a new model entry.

case $# in
    3) FOLDER=$1; MODEL=$2; PYTHON=$3;;
    # fix: exit with non-zero status on usage error (was implicit exit 0)
    *) echo "update_output_folder.sh folder model python"; exit 1;;
esac


SCRIPTFOLDER=$(dirname "$0")


# update setup
# fix: message said setup.py while the file actually updated is setup.yml
echo "updating $FOLDER/setup.yml..."
"$PYTHON" "$SCRIPTFOLDER/update_setups.py" "$FOLDER/setup.yml" --model "$MODEL"

# update checkpoints
for fn in $(find "$FOLDER" -type f -name "checkpoint*.pth")
do
    echo "updating $fn..."
    "$PYTHON" "$SCRIPTFOLDER/update_checkpoints.py" "$fn" --model "$MODEL"
done
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" script for updating setup files with new setup entries
|
||||
|
||||
Use this script to update older outputs with newly introduced
|
||||
parameters. (Saves us the trouble of backward compatibility)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
parser = argparse.ArgumentParser()

parser.add_argument('setup_file', type=str, help='setup to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)

args = parser.parse_args()

# load setup (FullLoader kept for behavior compatibility; only use on
# trusted, locally generated setup files)
with open(args.setup_file, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# update model entry (idiom fix: compare against None with 'is not None'
# instead of 'type(...) != type(None)')
if args.model is not None:
    setup['lpcnet']['model'] = args.model

# dump result back to the same file
with open(args.setup_file, 'w') as f:
    yaml.dump(setup, f)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
from models import model_dict
|
||||
from utils.data import load_features
|
||||
from utils.wav import wavwrite16
|
||||
|
||||
# set debug = True to run with hard-coded arguments instead of the CLI
debug = False
if debug:
    args = type('dummy', (object,),
    {
        'features' : 'features.f32',
        'checkpoint' : 'checkpoint.pth',
        'output' : 'out.wav',
        'version' : 2
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('features', type=str, help='feature file')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--version', type=int, help='feature version', default=2)

    args = parser.parse_args()


torch.set_num_threads(2)

version = args.version
feature_file = args.features
checkpoint_file = args.checkpoint

output_file = args.output
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model (idiom fix: use "'model' not in"; dropped the f-string prefix
# from a literal without placeholders)
if 'model' not in checkpoint['setup']['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = checkpoint['setup']['lpcnet']['model']

model = model_dict[model_name](checkpoint['setup']['lpcnet']['config'])

model.load_state_dict(checkpoint['state_dict'])

data = load_features(feature_file)

# run synthesis and store the result as 16-bit wav at 16 kHz
output = model.generate(data['features'], data['periods'], data['lpcs'])

wavwrite16(output_file, output.numpy(), 16000)
|
||||
272
managed_components/78__esp-opus/dnn/torch/lpcnet/train_lpcnet.py
Normal file
272
managed_components/78__esp-opus/dnn/torch/lpcnet/train_lpcnet.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from data import LPCNetDataset
|
||||
from models import model_dict
|
||||
from engine.lpcnet_engine import train_one_epoch, evaluate
|
||||
from utils.data import load_features
|
||||
from utils.wav import wavwrite16
|
||||
|
||||
|
||||
debug = False
|
||||
if debug:
|
||||
args = type('dummy', (object,),
|
||||
{
|
||||
'setup' : 'setup.yml',
|
||||
'output' : 'testout',
|
||||
'device' : None,
|
||||
'test_features' : None,
|
||||
'finalize': False,
|
||||
'initial_checkpoint': None,
|
||||
'no-redirect': False
|
||||
})()
|
||||
else:
|
||||
parser = argparse.ArgumentParser("train_lpcnet.py")
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
|
||||
parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
if args.finalize:
|
||||
if args.initial_checkpoint is None:
|
||||
raise ValueError('finalization requires initial checkpoint')
|
||||
|
||||
if 'sparsification' in setup['lpcnet']['config']:
|
||||
for sp_job in setup['lpcnet']['config']['sparsification'].values():
|
||||
sp_job['start'], sp_job['stop'] = 0, 0
|
||||
|
||||
setup['training']['lr'] = 1.0e-5
|
||||
setup['training']['lr_decay_factor'] = 0.0
|
||||
setup['training']['epochs'] = 1
|
||||
|
||||
checkpoint_prefix = 'checkpoint_finalize'
|
||||
output_prefix = 'output_finalize'
|
||||
setup_name = 'setup_finalize.yml'
|
||||
output_file='out_finalize.txt'
|
||||
else:
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'model' in setup['lpcnet']:
|
||||
print(f'warning: did not find model entry in setup, using default lpcnet')
|
||||
model_name = 'lpcnet'
|
||||
else:
|
||||
model_name = setup['lpcnet']['model']
|
||||
|
||||
# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires a status argument and would raise
        # TypeError here; use sys.exit to abort cleanly instead
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
# prepare inference test if wanted
|
||||
run_inference_test = False
|
||||
if type(args.test_features) != type(None):
|
||||
test_features = load_features(args.test_features)
|
||||
inference_test_dir = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(inference_test_dir, exist_ok=True)
|
||||
run_inference_test = True
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
|
||||
# load training dataset
|
||||
lpcnet_config = setup['lpcnet']['config']
|
||||
data = LPCNetDataset( setup['dataset'],
|
||||
features=lpcnet_config['features'],
|
||||
input_signals=lpcnet_config['signals'],
|
||||
target=lpcnet_config['target'],
|
||||
frames_per_sample=setup['training']['frames_per_sample'],
|
||||
feature_history=lpcnet_config['feature_history'],
|
||||
feature_lookahead=lpcnet_config['feature_lookahead'],
|
||||
lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = LPCNetDataset( setup['validation_dataset'],
|
||||
features=lpcnet_config['features'],
|
||||
input_signals=lpcnet_config['signals'],
|
||||
target=lpcnet_config['target'],
|
||||
frames_per_sample=setup['training']['frames_per_sample'],
|
||||
feature_history=lpcnet_config['feature_history'],
|
||||
feature_lookahead=lpcnet_config['feature_lookahead'],
|
||||
lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](setup['lpcnet']['config'])
|
||||
|
||||
if args.initial_checkpoint is not None:
|
||||
print(f"loading state dict from {args.initial_checkpoint}...")
|
||||
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(chkpt['state_dict'])
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# push model to device
|
||||
model.to(device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr)
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# loss
|
||||
criterion = torch.nn.NLLLoss()
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
best_loss = 1e9
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['loss'] = new_loss
|
||||
|
||||
if run_validation:
|
||||
print("running validation...")
|
||||
validation_loss = evaluate(model, criterion, validation_dataloader, device)
|
||||
checkpoint['validation_loss'] = validation_loss
|
||||
|
||||
if validation_loss < best_loss:
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
|
||||
best_loss = validation_loss
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
# run inference test
|
||||
if run_inference_test:
|
||||
model.to("cpu")
|
||||
print("running inference test...")
|
||||
|
||||
output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])
|
||||
|
||||
testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')
|
||||
|
||||
wavwrite16(testfilename, output.numpy(), 16000)
|
||||
|
||||
model.to(device)
|
||||
|
||||
print()
|
||||
@@ -0,0 +1,4 @@
|
||||
from . import sparsification
|
||||
from . import data
|
||||
from . import pcm
|
||||
from . import sample
|
||||
141
managed_components/78__esp-opus/dnn/torch/lpcnet/utils/data.py
Normal file
141
managed_components/78__esp-opus/dnn/torch/lpcnet/utils/data.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def load_features(feature_file, version=2):
    """Read a raw float32 feature dump and split it into model inputs.

    Parameters
    ----------
    feature_file : str
        Path to the .f32 feature file.
    version : int
        Feature layout version (1 or 2).

    Returns
    -------
    dict
        'features' (cepstrum + pitch correlation), 'periods' (integer pitch
        periods) and 'lpcs' (LP coefficients), all torch tensors.
    """
    # per-version frame layout: frame length and [start, stop) column slices
    layouts = {
        2: (36, {'cepstrum': (0, 18), 'periods': (18, 19), 'pitch_corr': (19, 20), 'lpc': (20, 36)}),
        1: (55, {'cepstrum': (0, 18), 'periods': (36, 37), 'pitch_corr': (37, 38), 'lpc': (39, 55)}),
    }

    if version not in layouts:
        raise ValueError(f'unknown feature version: {version}')

    frame_length, layout = layouts[version]

    raw = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw = raw.reshape((-1, frame_length))

    cepstrum = raw[:, layout['cepstrum'][0] : layout['cepstrum'][1]]
    pitch_corr = raw[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
    features = torch.cat([cepstrum, pitch_corr], dim=1)

    lpcs = raw[:, layout['lpc'][0] : layout['lpc'][1]]
    # affine re-scaling of the stored period value to an integer period
    periods = (0.1 + 50 * raw[:, layout['periods'][0] : layout['periods'][1]] + 100).long()

    return {'features': features, 'periods': periods, 'lpcs': lpcs}
|
||||
|
||||
|
||||
|
||||
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize a 16-bit raw signal and interleave it into a new data file.

    The output file matches the shape of the reference data file; the
    pre-emphasized signal, shifted by `offset` samples, is written into both
    interleaved slots, and all remaining entries are zeroed. A side file
    '<signal>_preemph.raw' with the pre-emphasized signal is also produced.
    """
    reference = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    # the pre-emphasized signal is stored next to the input file
    preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    preemphasized = np.memmap(preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    frame_size = 160
    assert len(signal) % frame_size == 0

    # frame-wise first-order pre-emphasis; `memory` carries the last sample
    # of the previous frame across the frame boundary
    memory = np.zeros(1)
    for frame in range(len(signal) // frame_size):
        lo, hi = frame * frame_size, (frame + 1) * frame_size
        preemphasized[lo : hi] = np.convolve(np.concatenate((memory, signal[lo : hi])), [1, -preemph_factor], mode='valid')
        memory = signal[hi - 1 : hi]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=reference.shape)

    new_data[:] = 0
    usable = len(signal) - offset
    # duplicate the shifted signal into both interleaved slots
    new_data[1 : 2 * usable + 1 : 2] = preemphasized[offset:]
    new_data[2 : 2 * usable + 2 : 2] = preemphasized[offset:]
|
||||
|
||||
|
||||
def parse_warpq_scores(output_file):
    """Extract all WARP-Q scores from a test-run output file.

    Only lines that begin with 'WARP-Q score:' are considered; the numeric
    value after the prefix is returned as a float, in file order.
    """
    prefix = "WARP-Q score:"

    with open(output_file, "r") as f:
        return [float(line.split(prefix)[-1]) for line in f if line.startswith(prefix)]
|
||||
|
||||
|
||||
def parse_stats_file(file):
    """Read three colon-separated float values from the first three lines.

    Returns the tuple (mean, bt_mean, top_mean) as written by the test
    suite's stats files.
    """
    with open(file, "r") as f:
        first, second, third = f.readlines()[:3]

    mean = float(first.split(":")[-1])
    bt_mean = float(second.split(":")[-1])
    top_mean = float(third.split(":")[-1])

    return mean, bt_mean, top_mean
|
||||
|
||||
def collect_test_stats(test_folder):
    """Collect statistics for all discovered metrics from a test folder.

    Scans `test_folder` for files named 'stats_<metric>.txt' and returns a
    dict mapping metric name to [mean, bt_mean, top_mean]. Unknown metric
    names are still collected, but a warning is printed.
    """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()

    for filename in os.listdir(test_folder):
        if not filename.startswith('stats_'):
            continue

        metric = filename[len("stats_") : -len(".txt")]

        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        # parse_stats_file returns (mean, bt_mean, top_mean)
        results[metric] = list(parse_stats_file(os.path.join(test_folder, filename)))

    return results
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
|
||||
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
|
||||
|
||||
# reset gate
|
||||
start, stop = 0 * hidden_size, 1 * hidden_size
|
||||
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# update gate
|
||||
start, stop = 1 * hidden_size, 2 * hidden_size
|
||||
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# new gate
|
||||
start, stop = 2 * hidden_size, 3 * hidden_size
|
||||
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
|
||||
|
||||
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """

    global _folder
    _folder = folder

    if os.path.exists(folder):
        # existing data is not removed; subsequent writes may mix with old files
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
|
||||
|
||||
def write_data(key, data, fs):
    """ appends data to previous data written under key

    Parameters:
    -----------
    key : str
        name the data is stored under; used as basename of the .bin/.yml files
    data : np.ndarray or torch.Tensor
        one frame of data; its shape and dtype must stay constant per key
    fs : int
        frame rate of the data; must stay constant per key
    """

    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    # first write for this key: open the binary output file and pin down
    # frame rate, shape and dtype; a sidecar .yml stores the metadata so
    # read_data can reconstruct the arrays later
    if not key in _state:
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # subsequent writes must match the metadata recorded on the first write
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    # raw bytes are appended; read_data reshapes with the stored dim
    _state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
    """ closes all open endoscopy data files and resets the module state

    Parameters:
    -----------
    folder : str
        unused; kept for backward compatibility with existing callers
    """
    for entry in _state.values():
        entry['fid'].close()

    # drop the closed handles so a later init()/write_data() cycle starts
    # clean instead of writing to closed files
    _state.clear()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """

    return_dict = dict()

    # every key has a .yml metadata file next to its .bin payload
    for name in os.listdir(folder):
        if not name.endswith('.yml'):
            continue
        key = name[:-4]

        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=value['dtype'])

        # leading -1 recovers the number of appended frames
        value['data'] = raw.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
    """ calculated the best 2d reshape of shape given the target ratio (rows/cols)"""

    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    # start at the column count matching the target ratio, then walk down
    # to the nearest divisor of the pixel count
    num_columns = int((pixel_count / target_ratio)**.5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
    """ decides between line-plot and image display and picks a display shape """

    # scalar data may arrive with an empty shape
    if len(shape) == 0:
        shape = (1,)

    # total number of entries per frame
    pixel_count = shape[0]
    for extent in shape[1:]:
        pixel_count *= extent

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders the data returned by read_data as an mp4 animation

    Each key gets its own subplot; 1-d data is drawn as a sliding line plot,
    multi-dimensional data as a per-frame image.

    Parameters:
    -----------
    data : dict
        mapping key -> {'data', 'fs', 'dim', ...} as produced by read_data
    filename : str
        output path; '.mp4' is appended if missing
    start_index, stop_index : int
        frame range to render; a negative stop_index counts from the end
    interval : int
        delay between frames in milliseconds
    half_signal_window_length : int
        half-width of the window shown in line plots
    """

    # determine plot setup: roughly 4:3 grid of subplots
    num_keys = len(data.keys())

    num_rows = int((num_keys * 3/4) ** .5)

    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    # keys may have different frame rates; down_factor maps the global frame
    # index to each key's own rate
    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()

        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the frame range so the plot window never leaves the data
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but both branches below
            # index with the raw `index`; confirm whether per-key downsampling
            # was intended here
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .dual_fc import DualFC
|
||||
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
|
||||
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding
|
||||
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class DualFC(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(DualFC, self).__init__()
|
||||
|
||||
self.dense1 = nn.Linear(input_dim, output_dim)
|
||||
self.dense2 = nn.Linear(input_dim, output_dim)
|
||||
|
||||
self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" module implementing PCM embeddings for LPCNet """
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PCMEmbedding(nn.Module):
    """ trainable embedding for pcm levels

    Weights are initialized to a noisy linear ramp over the levels so that
    neighbouring pcm levels start out with similar embeddings.
    """
    def __init__(self, embed_dim=128, num_levels=256):
        super(PCMEmbedding, self).__init__()

        self.embed_dim   = embed_dim
        self.num_levels  = num_levels

        # bug fix: was nn.Embedding(self.num_levels, self.num_dim); no
        # attribute num_dim exists — the embedding width is embed_dim
        self.embedding = nn.Embedding(self.num_levels, self.embed_dim)

        # initialize: uniform noise around a per-level ramp
        with torch.no_grad():
            num_rows, num_cols = self.num_levels, self.embed_dim
            a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
            for i in range(num_rows):
                a[i, :] += m.sqrt(12) * (i - num_rows / 2)
            self.embedding.weight[:, :] = 0.1 * a

    def forward(self, x):
        # bug fix: was self.embeddint(x) (typo), which raised AttributeError
        return self.embedding(x)
|
||||
|
||||
|
||||
class DifferentiablePCMEmbedding(PCMEmbedding):
    """ PCM embedding accepting fractional levels

    Linearly interpolates between the embeddings of the two integer levels
    enclosing x, which keeps the lookup differentiable w.r.t. x.
    """
    def __init__(self, embed_dim, num_levels=256):
        super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)

    def forward(self, x):
        # bug fix: was (x - torch.floor(x)).detach().long(), i.e. the
        # fractional part truncated to zero, so x_int was always 0
        x_int  = torch.floor(x).detach().long()
        x_frac = x - x_int
        # bug fix: torch.minimum does not accept a python int, and the largest
        # valid embedding index is num_levels - 1, not num_levels
        x_next = torch.clamp(x_int + 1, max=self.num_levels - 1)

        embed_0 = self.embedding(x_int)
        embed_1 = self.embedding(x_next)

        # broadcast the interpolation weight over the embedding dimension
        w = x_frac.unsqueeze(-1)
        return (1 - w) * embed_0 + w * embed_1
|
||||
@@ -0,0 +1,497 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from re import sub
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
||||
|
||||
def get_subconditioner( method,
                        number_of_subsamples,
                        pcm_embedding_size,
                        state_size,
                        pcm_levels,
                        number_of_signals,
                        **kwargs):
    """ factory returning the subconditioner selected by method

    method must be one of 'additive', 'concatenative' or 'modulative';
    an unknown method raises KeyError.
    """

    constructors = {
        'additive' : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative' : ModulativeSubconditioner
    }

    constructor = constructors[method]

    return constructor(number_of_subsamples, pcm_embedding_size, state_size,
                       pcm_levels, number_of_signals, **kwargs)
|
||||
|
||||
|
||||
class Subconditioner(nn.Module):
    """ upsampling by subconditioning

    Upsamples a sequence of states conditioning on pcm signals and
    optionally a feature vector. Abstract base class: subclasses must
    implement forward, single_step and get_output_dim.
    """
    def __init__(self):
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        """ batched subconditioning over a whole sequence """
        # NotImplementedError (a subclass of Exception, so existing handlers
        # still match) replaces the original bare Exception
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        """ subconditioning of a single state during inference """
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        """ size of the subconditioned state at subsample position index """
        raise NotImplementedError("Base class should not be called")
|
||||
|
||||
|
||||
class AdditiveSubconditioner(Subconditioner):
    # Subconditioning by adding pcm-signal embeddings onto the state;
    # requires pcm_embedding_size == state_size so the sum is well-defined.
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings; must equal state_size
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm levels (embedding table size)
        number_of_signals : int
            number of pcm signals used for subconditioning
        """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')

        # index 0 stays None: the first subsample position reuses the state
        # unchanged; modules are registered manually because self.embeddings
        # is a plain python list
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """

        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # embed every s-th signal frame starting at offset i
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)

            # embeddings accumulate across subsample positions
            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            # position 0 passes the state through unmodified
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            # sum over the signal axis to reduce to embedding size
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        # addition never changes the state size
        return self.state_size

    def get_average_flops_per_step(self):
        # position 0 is free; each of the other s-1 positions adds
        # number_of_signals embeddings of size pcm_embedding_size
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
|
||||
|
||||
|
||||
class ConcatenativeSubconditioner(Subconditioner):
    # Subconditioning by concatenating pcm-signal embeddings to the state;
    # the output dimension therefore grows with the subsample position.
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm levels (embedding table size)
        number_of_signals : int
            number of pcm signals used for subconditioning
        recurrent : bool
            if True, embeddings accumulate across subsample positions;
            otherwise each position conditions on the original state only
        """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        # in recurrent mode position 0 has no embedding (state passes through)
        self.embeddings = []
        start_index = 0
        if self.recurrent:
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            # register manually because self.embeddings is a plain python list
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            # embed every s-th signal frame starting at offset i
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                # keep stacking embeddings onto the growing state
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                # always condition on the original state
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # cleanup: the original ended with two consecutive `return c_state`
        # statements; collapsed into a single exit per path
        if index == 0 and self.recurrent:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if not self.recurrent and index > 0:
            # overwrite previous conditioning vector
            c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
        else:
            c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        # embedding lookup and concatenation need no multiply-accumulates
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            # one embedding block per position already concatenated
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
|
||||
|
||||
class ModulativeSubconditioner(Subconditioner):
    # Subconditioning by the affine modulation
    # state -> tanh((1 + alpha) * state + beta), with alpha and beta predicted
    # from pcm-signal embeddings (and optionally the running state).
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm levels (embedding table size)
        number_of_signals : int
            number of pcm signals used for subconditioning
        state_recurrent : bool
            if True, a linear transform of the running state is appended to
            the embedding before predicting alpha and beta
        """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 holds no modules: position 0 passes the state through
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch

        state : torch.tensor
            state to sub-condition

        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """

        # cleanup: the original ended with two consecutive `return c_state`
        # statements; collapsed into a single exit per path
        if index == 0:
            return state

        embed_signals = self.embeddings[index](signals)
        c = torch.flatten(embed_signals, -2)
        if self.state_recurrent:
            r_state = self.state_transform(state)
            c = torch.cat((c, r_state), dim=-1)
        alpha = torch.tanh(self.alphas[index](c))
        beta = torch.tanh(self.betas[index](c))
        c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        # modulation never changes the state size
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps (position 0 is free)
        flops *= (s - 1) / s

        return flops
|
||||
|
||||
class ComparitiveSubconditioner(Subconditioner):
    # Subconditioning by modulating a transform of the state with alpha/beta
    # predicted from pcm-signal embeddings.
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison

        Parameters:
        -----------
        number_of_subsamples : int
            upsampling factor s
        pcm_embedding_size : int
            size of the pcm embeddings
        state_size : int
            size of the states to be subconditioned
        pcm_levels : int
            number of pcm levels (embedding table size)
        number_of_signals : int
            number of pcm signals used for subconditioning
        error_index : int
            position of the error signal among the conditioning signals
        apply_gate : bool
            if True, an additional gating dense layer is created
        normalize : bool
            normalization flag stored for use by subclasses/callers
        """

        super(ComparitiveSubconditioner, self).__init__()

        # bug fix: the constructor arguments were never stored on self, so
        # every subsequent self.<attr> access raised AttributeError
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        # bug fix: 'number_of_signales' was a typo for number_of_signals
        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms; index 0 holds no embedding
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states; see ModulativeSubconditioner.forward """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            # NOTE(review): this uses the shared self.alpha_dense/self.beta_dense,
            # not the per-position self.alpha_denses/self.beta_denses built in
            # __init__ — confirm which was intended
            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def find(a, v):
    """ returns the index of the first occurrence of v in a, or -1 if absent """
    try:
        return a.index(v)
    except ValueError:
        # narrowed from a bare except: list.index raises ValueError when
        # v is not present; -1 mirrors the C-style "not found" convention
        return -1
|
||||
|
||||
def interleave_tensors(tensors, dim=-2):
    """Interleave a list of equally-shaped tensors along the sequence dimension.

    Each tensor contributes one entry per position, so the output length along
    `dim - 1` is len(tensors) times the input length.
    """
    stacked = torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)
    return torch.flatten(stacked, dim - 1, dim)
|
||||
|
||||
def _interleave(x, pcm_levels=256):
|
||||
|
||||
repeats = pcm_levels // (2*x.size(-1))
|
||||
x = x.unsqueeze(-1)
|
||||
p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
|
||||
|
||||
return p
|
||||
|
||||
def get_pdf_from_tree(x):
    """Convert hierarchical (binary-tree) split probabilities into a flat pdf.

    The last dimension of x holds the tree nodes; levels are multiplied
    together from the root down, each level broadcast to the leaves via
    _interleave.
    """
    pcm_levels = x.size(-1)

    pdf = _interleave(x[..., 1:2])
    level = 4
    while level <= pcm_levels:
        pdf = pdf * _interleave(x[..., level // 2: level])
        level *= 2

    return pdf
|
||||
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
def clip_to_int16(x):
    """Clamp a numeric value to the signed 16-bit integer range.

    Values below -32768 map to -32768, values above 32767 map to 32767,
    everything in between is returned unchanged.
    """
    lower = -2**15
    upper = 2**15 - 1
    if x < lower:
        return lower
    if x > upper:
        return upper
    return x
|
||||
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sample_excitation(probs, pitch_corr):
    """Draw one excitation sample from a (possibly sharpened) distribution.

    The distribution is sharpened more aggressively for high pitch
    correlation (temperature lowering), small tail probabilities are cut off,
    and a single index is drawn with torch.multinomial.
    """
    def _renormalize(p):
        return p / (p.sum() + 1e-18)

    # lowering the temperature: exponent grows with pitch correlation
    temperature_exponent = 1 + max(0, 1.5 * pitch_corr - 0.5)
    sharpened = _renormalize(probs ** temperature_exponent)

    # cut-off tails below a small floor, then renormalize again
    trimmed = _renormalize(torch.maximum(sharpened - 0.002, torch.FloatTensor([0])))

    # sample a single excitation index
    return torch.multinomial(trimmed.squeeze(), 1)
|
||||
@@ -0,0 +1,2 @@
|
||||
from .gru_sparsifier import GRUSparsifier
|
||||
from .common import sparsify_matrix, calculate_gru_flops_per_step
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """Apply block-structured magnitude pruning to a 2-d tensor.

    The matrix is partitioned into sub-blocks of shape block_size; the blocks
    with the highest energy are kept so that roughly `density` of all blocks
    survive, and all remaining blocks are zeroed out.

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : float
        target fraction of blocks to keep, in [0, 1]
    block_size : [int, int]
        block size dimensions; must divide the matrix dimensions
    keep_diagonal : bool
        if True, the diagonal is preserved unchanged; requires a square matrix
    return_mask : bool
        if True, the element-level binary mask is returned alongside the result
    """
    n_rows, n_cols = matrix.shape
    rows_per_block, cols_per_block = block_size

    if n_rows % rows_per_block or n_cols % cols_per_block:
        raise ValueError(f"block size {(rows_per_block, cols_per_block)} does not divide matrix size {(n_rows, n_cols)}")

    # split off the diagonal so it survives masking untouched
    if keep_diagonal:
        if n_rows != n_cols:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
        diagonal_part = torch.diag(torch.diag(matrix))
        matrix = matrix - diagonal_part
    else:
        diagonal_part = torch.zeros_like(matrix)

    # energy per sub-block
    blocks = torch.reshape(matrix, (n_rows // rows_per_block, rows_per_block, n_cols // cols_per_block, cols_per_block))
    block_energies = torch.sum(torch.sum(blocks ** 2, dim=3), dim=1)

    total_blocks = (n_rows * n_cols) // (rows_per_block * cols_per_block)
    survivors = round(total_blocks * density)

    # energy threshold below which a block is dropped; when survivors == 0 the
    # threshold is 0, so no block is masked (matches original behavior)
    if survivors == 0:
        threshold = 0
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-survivors]

    # block-level mask expanded to element resolution
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(mask, rows_per_block, dim=0)
    mask = torch.repeat_interleave(mask, cols_per_block, dim=1)

    masked_matrix = mask * matrix + diagonal_part

    return (masked_matrix, mask) if return_mask else masked_matrix
|
||||
|
||||
def calculate_gru_flops_per_step(gru, sparsification_dict=None, drop_input=False):
    """Estimate the number of floating point operations per GRU time step.

    Parameters
    ----------
    gru : torch.nn.GRU
        GRU whose per-step cost is estimated (uses input_size and hidden_size)
    sparsification_dict : dict or None
        optional mapping from weight-block names ('W_ir', 'W_iz', 'W_in',
        'W_hr', 'W_hz', 'W_hn') to tuples whose first entry is the target
        density of that block; missing entries count as dense (density 1)
    drop_input : bool
        if True, the input matrix-vector products are excluded from the count

    Returns
    -------
    float
        estimated FLOPs per time step
    """
    # avoid the mutable-default-argument pitfall of the original signature
    # (sparsification_dict=dict()); None behaves identically for callers
    if sparsification_dict is None:
        sparsification_dict = dict()

    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    # average density of the three input weight blocks
    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    # average density of the three recurrent weight blocks
    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications (factor 2: multiply + add)
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .common import sparsify_matrix
|
||||
|
||||
|
||||
class GRUSparsifier:
    """Gradually prunes the weight matrices of one or more torch.nn.GRU modules
    during training by zeroing low-energy sub-blocks in place."""

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
            should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...         'W_ir' : (0.5, [2, 2], False),
        ...         'W_iz' : (0.6, [2, 2], False),
        ...         'W_in' : (0.7, [2, 2], False),
        ...         'W_hr' : (0.1, [4, 4], True),
        ...         'W_hz' : (0.2, [4, 4], True),
        ...         'W_hn' : (0.3, [4, 4], True),
        ...         }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...         sparsifier.step()
        """
        # just copying parameters...
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        # last applied mask per weight block, used to detect post-`stop`
        # mask churn.  NOTE(review): this dict is shared across all GRUs in
        # task_list, so with more than one GRU the change detection compares
        # masks across modules — confirm this is intended.
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def step(self, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """
        # compute current interpolation factor
        self.step_counter += 1

        if self.step_counter < self.start:
            # sparsification has not started yet
            return
        elif self.step_counter < self.stop:
            # update only every self.interval-th interval
            if self.step_counter % self.interval:
                return

            # alpha interpolates from 1 (no pruning) at `start` down to 0
            # (full target density) at `stop`
            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            # past `stop`: always prune at the full target density
            alpha = 0


        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: weight_ih_l0 stacks the r/z/n blocks row-wise,
                # so block i occupies rows [i*hidden_size, (i+1)*hidden_size)
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        # interpolate between dense (alpha=1) and target density
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        # prune the block in place on the GRU's weight tensor
                        gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        # warn if the mask still changes after sparsification
                        # should have converged
                        if type(self.last_masks[key]) != type(None):
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask

                # recurrent weights: same layout on weight_hh_l0
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if type(self.last_masks[key]) != type(None):
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: sparsify a small GRU over 100 steps and print the
    # interpolated densities at every call
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    # per weight block: (target density, block size, keep_diagonal)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    # start at step 0, finish at step 100, prune every 10 steps
    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from models import multi_rate_lpcnet
|
||||
import copy
|
||||
|
||||
# Registry of training setups, filled in at the bottom of this module.
# NOTE(review): this initial assignment is immediately overwritten by the
# final `setup_dict = {...}` below and appears redundant.
setup_dict = dict()

# Dataset layout, version 2: int16 signal file with (last_signal, signal)
# pairs and 36-dim feature frames.
dataset_template_v2 = {
    'version' : 2,
    'feature_file' : 'features.f32',
    'signal_file' : 'data.s16',
    'frame_length' : 160,
    'feature_frame_length' : 36,
    'signal_frame_length' : 2,
    'feature_dtype' : 'float32',
    'signal_dtype' : 'int16',
    'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
    'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
}

# Dataset layout, version 1: uint8 (mu-law) signal file with
# (last_signal, prediction, last_error, error) quadruples and 55-dim
# feature frames.
dataset_template_v1 = {
    'version' : 1,
    'feature_file' : 'features.f32',
    'signal_file' : 'data.u8',
    'frame_length' : 160,
    'feature_frame_length' : 55,
    'signal_frame_length' : 4,
    'feature_dtype' : 'float32',
    'signal_dtype' : 'uint8',
    'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
    'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
}

# lpcnet

# Model hyperparameters for the baseline LPCNet, including the block
# sparsification schedule for GRU A (recurrent weights) and GRU B (input
# weights).
lpcnet_config = {
    'frame_size' : 160,
    'gru_a_units' : 384,
    'gru_b_units' : 64,
    'feature_conditioning_dim' : 128,
    'feature_conv_kernel_size' : 3,
    'period_levels' : 257,
    'period_embedding_dim' : 64,
    'signal_embedding_dim' : 128,
    'signal_levels' : 256,
    'feature_dimension' : 19,
    'output_levels' : 256,
    'lpc_gamma' : 0.9,
    'features' : ['cepstrum', 'periods', 'pitch_corr'],
    'signals' : ['last_signal', 'prediction', 'last_error'],
    'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
                       'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
    'target' : 'error',
    'feature_history' : 2,
    'feature_lookahead' : 2,
    'sparsification' : {
        'gru_a' : {
            'start' : 10000,
            'stop' : 30000,
            'interval' : 100,
            'exponent' : 3,
            # per weight block: (target density, block size, keep_diagonal)
            'params' : {
                'W_hr' : (0.05, [4, 8], True),
                'W_hz' : (0.05, [4, 8], True),
                'W_hn' : (0.2, [4, 8], True)
            },
        },
        'gru_b' : {
            'start' : 10000,
            'stop' : 30000,
            'interval' : 100,
            'exponent' : 3,
            'params' : {
                'W_ir' : (0.5, [4, 8], False),
                'W_iz' : (0.5, [4, 8], False),
                'W_in' : (0.5, [4, 8], False)
            },
        }
    },
    'add_reference_phase' : False,
    'reference_phase_dim' : 0
}



# multi rate

# Sub-frame conditioning stages for the multi-rate LPCNet variant.
subconditioning = {
    'subconditioning_a' : {
        'number_of_subsamples' : 2,
        'method' : 'modulative',
        'signals' : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size' : 64,
        'kwargs' : dict()

    },
    'subconditioning_b' : {
        'number_of_subsamples' : 2,
        'method' : 'modulative',
        'signals' : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size' : 64,
        'kwargs' : dict()
    }
}

# NOTE(review): .copy() is shallow, so nested dicts (e.g. 'sparsification')
# are shared between lpcnet_config and multi_rate_lpcnet_config — confirm
# nothing mutates them.
multi_rate_lpcnet_config = lpcnet_config.copy()
multi_rate_lpcnet_config['subconditioning'] = subconditioning

# Default optimizer / loop settings shared by all setups.
training_default = {
    'batch_size' : 256,
    'epochs' : 20,
    'lr' : 1e-3,
    'lr_decay_factor' : 2.5e-5,
    'adam_betas' : [0.9, 0.99],
    'frames_per_sample' : 15
}

# Baseline setup: dataset path + model config + training defaults.
lpcnet_setup = {
    'dataset' : '/local/datasets/lpcnet_training',
    'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
    'training' : training_default
}

# Multi-rate setup derived from the baseline (deepcopy so the nested
# 'lpcnet' dict can be modified independently).
multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'

# Public registry: setup name -> full setup description.
setup_dict = {
    'lpcnet' : lpcnet_setup,
    'multi_rate' : multi_rate_lpcnet_setup
}
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def ulaw2lin(u):
    """Convert mu-law code values in [0, 255] to linear PCM amplitudes.

    Code 128 maps to 0; codes above/below 128 map to positive/negative
    amplitudes via the mu-law expansion curve.
    """
    lin_scale = 32768.0 / 255.0
    centered = u - 128
    sign = torch.sign(centered)
    magnitude = torch.abs(centered)
    return sign * lin_scale * (torch.exp(magnitude / 128. * m.log(256)) - 1)
|
||||
|
||||
|
||||
def lin2ulawq(x):
    """Convert linear PCM amplitudes to quantized mu-law codes in [0, 255].

    The result is rounded to integer code values; 0 maps to code 128.
    """
    compress_scale = 255.0 / 32768.0
    sign = torch.sign(x)
    magnitude = torch.abs(x)
    code = sign * (128 * torch.log(1 + compress_scale * magnitude) / m.log(256))
    return torch.clip(128 + torch.round(code), 0, 255)
|
||||
|
||||
def lin2ulaw(x):
|
||||
scale = 255.0 / 32768.0
|
||||
s = torch.sign(x)
|
||||
x = torch.abs(x)
|
||||
u = s * (128 * torch.log(1 + scale * x) / torch.log(256))
|
||||
u = torch.clip(128 + u, 0, 255)
|
||||
return u
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import wave
|
||||
|
||||
def wavwrite16(filename, x, fs):
    """ writes x as int16 to file with name filename

    If x.dtype is int16 x is written as is. Otherwise,
    it is scaled by 2**15 - 1 and converted to int16.
    """
    if x.dtype != 'int16':
        x = ((2**15 - 1) * x).astype('int16')

    with wave.open(filename, 'wb') as wav_file:
        # params: (nchannels, sampwidth, framerate, nframes, comptype, compname)
        wav_file.setparams((1, 2, fs, len(x), 'NONE', ""))
        wav_file.writeframes(x.tobytes())
|
||||
@@ -0,0 +1,18 @@
|
||||
## Neural Pitch Estimation
|
||||
|
||||
- Dataset Installation
|
||||
1. Download and unzip PTDB Dataset:
|
||||
wget https://www2.spsc.tugraz.at/databases/PTDB-TUG/SPEECH_DATA_ZIPPED.zip
|
||||
unzip SPEECH_DATA_ZIPPED.zip
|
||||
|
||||
2. Inside "SPEECH DATA" above, run ptdb_process.sh to combine male/female
|
||||
|
||||
3. To Download and combine demand, simply run download_demand.sh
|
||||
|
||||
- LPCNet preparation
|
||||
1. To extract xcorr: add lpcnet_extractor.c, add the relevant functions to lpcnet_enc.c, register the new source and header files in Makefile.am, and compile to generate the ./lpcnet_xcorr_extractor object
|
||||
|
||||
- Dataset Augmentation and training (check out arguments to each of the following)
|
||||
1. Run data_augmentation.py
|
||||
2. Run training.py using augmented data
|
||||
3. Run experiments.py
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Perform Data Augmentation (Gain, Additive Noise, Random Filtering) on Input TTS Data
|
||||
1. Read in chunks and compute clean pitch first
|
||||
2. Then add in augmentation (Noise/Level/Response)
|
||||
- Adds filtered noise from the "Demand" dataset, https://zenodo.org/record/1227121#.XRKKxYhKiUk
|
||||
- When using the Demand Dataset, consider each channel as a possible noise input, and keep the first 4 minutes of noise for training
|
||||
3. Use this "augmented" audio for feature computation, and compute pitch using CREPE on the clean input
|
||||
|
||||
Notes: To ensure consistency with the discovered CREPE offset, we do the following
|
||||
- We pad the input audio to the zero-centered CREPE estimator with 80 zeros
|
||||
- We pad the input audio to our feature computation with 160 zeros to center them
|
||||
"""
|
||||
|
||||
# Command-line interface; parsed before the heavy imports below so that
# `--gpu_index` can be applied via CUDA_VISIBLE_DEVICES prior to loading
# any GPU-aware libraries.
import argparse
parser = argparse.ArgumentParser()

parser.add_argument('data', type=str, help='input raw audio data')
parser.add_argument('output', type=str, help='output directory')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
parser.add_argument('noise_dataset', type=str, help='Location of the Demand Datset')
# NOTE(review): argparse `type=bool` converts any non-empty string to True
# (e.g. "--flag_xcorr False" yields True) — consider `action='store_true'`.
parser.add_argument('--flag_xcorr', type=bool, help='Flag to additionally dump xcorr features',choices=[True,False],default = False,required = False)
parser.add_argument('--fraction_input_use', type=float, help='Fraction of input data to consider',default = 0.3,required = False)
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
parser.add_argument('--choice_augment', type=str, help='Choice of noise augmentation, either use additive synthetic noise or add noise from the demand dataset',choices = ['demand','synthetic'],default = "demand",required = False)
parser.add_argument('--fraction_clean', type=float, help='Fraction of data to keep clean (that is not augment with anything)',default = 0.2,required = False)
parser.add_argument('--chunk_size', type=int, help='Number of samples to augment with for each iteration',default = 80000,required = False)
parser.add_argument('--N', type=int, help='STFT window size',default = 320,required = False)
parser.add_argument('--H', type=int, help='STFT Hop size',default = 160,required = False)
parser.add_argument('--freq_keep', type=int, help='Number of Frequencies to keep',default = 30,required = False)

args = parser.parse_args()

# Restrict visible GPUs before importing CUDA-aware libraries (crepe/TF).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)

from utils import stft, random_filter

import numpy as np
import tqdm
import crepe
import random
import glob
import subprocess

# Memory-map the raw int16 audio and keep only the requested fraction.
data_full = np.memmap(args.data, dtype=np.int16,mode = 'r')
data = data_full[:(int)(args.fraction_input_use*data_full.shape[0])]

# list_features = []
# Accumulators for per-frame pitch (in cents) and CREPE confidence.
list_cents = []
list_confidences = []

N = args.N  # STFT window size
H = args.H  # STFT hop size
freq_keep = args.freq_keep  # number of frequency bins kept per feature group
# Minimum/Maximum periods, decided by LPCNet
min_period = 32
max_period = 256
f_ref = 16000/max_period  # reference frequency for the cent scale (16 kHz audio assumed)
chunk_size = args.chunk_size
num_frames_chunk = chunk_size//H
# Indices of the kept low-frequency bins in each of the three stacked
# (N//2 + 1)-sized feature groups.
list_indices_keep = np.concatenate([np.arange(freq_keep), (N//2 + 1) + np.arange(freq_keep), 2*(N//2 + 1) + np.arange(freq_keep)])

# Pre-sized memory-mapped outputs for instantaneous-frequency features and,
# optionally, LPCNet xcorr features (one row per frame, last chunk dropped).
output_IF = np.memmap(args.output + '_iffeat.f32', dtype=np.float32, shape=(((data.shape[0]//chunk_size - 1)//1)*num_frames_chunk,list_indices_keep.shape[0]), mode='w+')
if args.flag_xcorr:
    output_xcorr = np.memmap(args.output + '_xcorr.f32', dtype=np.float32, shape=(((data.shape[0]//chunk_size - 1)//1)*num_frames_chunk,257), mode='w+')

fraction_clean = args.fraction_clean

noise_dataset = args.noise_dataset
|
||||
|
||||
for i in tqdm.trange((data.shape[0]//chunk_size - 1)//1):
|
||||
chunk = data[i*chunk_size:(i + 1)*chunk_size]/(2**15 - 1)
|
||||
|
||||
# Clean Pitch/Confidence Estimate
|
||||
# Padding input to CREPE by 80 samples to ensure it aligns
|
||||
_, pitch, confidence, _ = crepe.predict(np.concatenate([np.zeros(80),chunk]), 16000, center=True, viterbi=True,verbose=0)
|
||||
cent = 1200*np.log2(np.divide(pitch, f_ref, out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)
|
||||
|
||||
# Filter out of range pitches/confidences
|
||||
confidence[pitch < 16000/max_period] = 0
|
||||
confidence[pitch > 16000/min_period] = 0
|
||||
|
||||
# Keep fraction of data clean, augment only 1 minus the fraction
|
||||
if (np.random.rand() > fraction_clean):
|
||||
# Response, generate controlled/random 2nd order IIR filter and filter chunk
|
||||
chunk = random_filter(chunk)
|
||||
|
||||
# Level/Gain response {scale by random gain between 1.0e-3 and 10}
|
||||
# Generate random gain in dB and then convert to scale
|
||||
g_dB = np.random.uniform(low = -60, high = 20, size = 1)
|
||||
# g_dB = 0
|
||||
g = 10**(g_dB/20)
|
||||
|
||||
# Noise Addition {Add random SNR 2nd order randomly colored noise}
|
||||
# Generate noise SNR value and add corresponding noise
|
||||
snr_dB = np.random.uniform(low = -20, high = 30, size = 1)
|
||||
|
||||
if args.choice_augment == 'synthetic':
|
||||
n = np.random.randn(chunk_size)
|
||||
else:
|
||||
list_noisefiles = noise_dataset + '*.wav'
|
||||
noise_file = random.choice(glob.glob(list_noisefiles))
|
||||
n = np.memmap(noise_file, dtype=np.int16,mode = 'r')/(2**15 - 1)
|
||||
rand_range = np.random.randint(low = 0, high = (n.shape[0] - 16000*60 - chunk.shape[0])) # 16000 is subtracted because we will use the last 1 minutes of noise for testing
|
||||
n = n[rand_range:rand_range + chunk.shape[0]]
|
||||
|
||||
# Randomly filter the sampled noise as well
|
||||
n = random_filter(n)
|
||||
# generate random prime number between 0,500 and make those samples of noise 0 (to prevent GRU from picking up temporal patterns)
|
||||
Nprime = random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541])
|
||||
n[chunk_size - Nprime:] = np.zeros(Nprime)
|
||||
snr_multiplier = np.sqrt((np.sum(np.abs(chunk)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
|
||||
|
||||
chunk = g*(chunk + snr_multiplier*n)
|
||||
|
||||
# Zero pad input audio by 160 to center the frames
|
||||
spec = stft(x = np.concatenate([np.zeros(160),chunk]), w = 'boxcar', N = N, H = H).T
|
||||
phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
|
||||
phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
|
||||
feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
|
||||
feature = feature[:,list_indices_keep]
|
||||
|
||||
if args.flag_xcorr:
|
||||
# Dump noisy audio into temp file
|
||||
data_temp = np.memmap('./temp_augment.raw', dtype=np.int16, shape=(chunk.shape[0]), mode='w+')
|
||||
# data_temp[:chunk.shape[0]] = (chunk/(np.max(np.abs(chunk)))*(2**15 - 1)).astype(np.int16)
|
||||
data_temp[:chunk.shape[0]] = ((chunk)*(2**15 - 1)).astype(np.int16)
|
||||
|
||||
subprocess.run([args.path_lpcnet_extractor, './temp_augment.raw', './temp_augment_xcorr.f32'])
|
||||
feature_xcorr = np.flip(np.fromfile('./temp_augment_xcorr.f32', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
|
||||
ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
|
||||
feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)
|
||||
|
||||
os.remove('./temp_augment.raw')
|
||||
os.remove('./temp_augment_xcorr.f32')
|
||||
num_frames = min(cent.shape[0],feature.shape[0],feature_xcorr.shape[0],num_frames_chunk)
|
||||
feature = feature[:num_frames,:]
|
||||
cent = cent[:num_frames]
|
||||
confidence = confidence[:num_frames]
|
||||
feature_xcorr = feature_xcorr[:num_frames]
|
||||
output_IF[i*num_frames_chunk:(i + 1)*num_frames_chunk,:] = feature
|
||||
output_xcorr[i*num_frames_chunk:(i + 1)*num_frames_chunk,:] = feature_xcorr
|
||||
list_cents.append(cent)
|
||||
list_confidences.append(confidence)
|
||||
|
||||
list_cents = np.hstack(list_cents)
|
||||
list_confidences = np.hstack(list_confidences)
|
||||
|
||||
np.save(args.output + '_pitches',np.vstack([list_cents,list_confidences]))
|
||||
@@ -0,0 +1,43 @@
|
||||
# Fetch the 16 kHz DEMAND noise recordings used for augmentation.
# One zip per recording environment; a loop replaces 17 copy-pasted wget lines.
for env in DKITCHEN DLIVING DWASHING NFIELD NPARK NRIVER OHALLWAY OMEETING \
           OOFFICE PCAFETER PRESTO PSTATION TMETRO TCAR TBUS STRAFFIC SPSQUARE; do
    wget "https://zenodo.org/record/1227121/files/${env}_16k.zip"
done

unzip '*.zip'

# Flatten every per-environment channel into one folder, prefixing each wav
# with its environment name (e.g. DKITCHEN/ch01.wav -> DKITCHEN+ch01.wav).
mkdir -p ./combined_demand_channels/
for file in */*.wav; do
    parentdir="$(dirname "$file")"
    echo "$parentdir"
    fname="$(basename "$file")"
    # Quote all expansions so paths containing spaces survive word splitting
    # (the original unquoted cp/echo broke on such paths).
    cp "$file" "./combined_demand_channels/${parentdir}+${fname}"
done
|
||||
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Evaluation script to compute the Raw Pitch Accuracy
|
||||
Procedure:
|
||||
- Look at all voiced frames in file
|
||||
- Compute number of pitches in those frames that lie within a 50 cent threshold
|
||||
RPA = (Total number of pitches within threshold summed across all files)/(Total number of voiced frames summed accross all files)
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
from prettytable import PrettyTable
|
||||
import numpy as np
|
||||
import glob
|
||||
import random
|
||||
import tqdm
|
||||
import torch
|
||||
import librosa
|
||||
import json
|
||||
from utils import stft, random_filter, feature_xform
|
||||
import subprocess
|
||||
import crepe
|
||||
|
||||
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def rca(reference, input, voicing, thresh = 25):
    """Raw chroma/pitch accuracy numerator: the number of voiced frames whose
    estimate lies within `thresh` cents of the reference.

    Note: the parameter name `input` shadows the builtin but is kept for
    interface compatibility with existing callers.
    """
    abs_err = np.abs(reference - input)
    voiced_errors = abs_err[np.flatnonzero(voicing)]
    return int(np.count_nonzero(voiced_errors < thresh))
|
||||
|
||||
def sweep_rca(reference, input, voicing, thresh = 25, ind_arr = None):
    """
    Best raw-accuracy count over small circular shifts of the estimate.

    Rolls `input` by each offset in `ind_arr` and returns the maximum `rca`
    count, compensating for a constant framing offset between the reference
    and estimated pitch tracks.

    `ind_arr=None` (new default) expands to np.arange(-10, 10), matching the
    old behaviour; the previous `ind_arr=np.arange(-10,10)` default was
    evaluated once at import time and shared across all calls.
    """
    if ind_arr is None:
        ind_arr = np.arange(-10, 10)
    # One score per candidate shift; np.max keeps the numpy-integer return
    # type of the previous implementation.
    scores = [rca(reference, np.roll(input, i), voicing, thresh) for i in ind_arr]
    return np.max(scores)
|
||||
|
||||
def rpa(model,device = 'cpu',data_format = 'if'):
    """Print the Raw Pitch Accuracy of `model` on (up to) 1000 PTDB files.

    data_format selects the features fed to the model: 'if', 'xcorr', or
    anything else for the concatenated joint features.
    NOTE(review): dataset paths and the lpcnet extractor path are hardcoded
    below; confirm they match the local environment before running.
    """
    list_files = glob.glob('/home/ubuntu/Code/Datasets/SPEECH DATA/combined_mic_16k_raw/*.raw')
    dir_f0 = '/home/ubuntu/Code/Datasets/SPEECH DATA/combine_f0_ptdb/'
    # random_shuffle = list(np.random.permutation(len(list_files)))
    random.shuffle(list_files)
    list_files = list_files[:1000]

    # Voiced-frame counts (denominators) and per-file hit counts (numerators),
    # overall and split by speaker sex (inferred from the filename).
    C_all = 0
    C_all_m = 0
    C_all_f = 0
    list_rca_model_all = []
    list_rca_male_all = []
    list_rca_female_all = []

    thresh = 50
    N = 320
    H = 160
    freq_keep = 30

    for idx in tqdm.trange(len(list_files)):
        audio_file = list_files[idx]
        file_name = os.path.basename(list_files[idx])[:-4]

        audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
        # Skip the PTDB header region so audio aligns with the reference f0.
        offset = 432
        audio = audio[offset:]
        rmse = np.squeeze(librosa.feature.rms(y = audio,frame_length = 320,hop_length = 160))

        # IF features: log-magnitude + normalized inter-frame phase advance.
        spec = stft(x = np.concatenate([np.zeros(160),audio]), w = 'boxcar', N = N, H = H).T
        phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
        phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
        idx_save = np.concatenate([np.arange(freq_keep),(N//2 + 1) + np.arange(freq_keep),2*(N//2 + 1) + np.arange(freq_keep)])
        feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
        feature_if = feature[:,idx_save]

        # xcorr features come from an external LPCNet binary via temp files.
        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
        data_temp[:audio.shape[0]] = (audio/(np.max(np.abs(audio)))*(2**15 - 1)).astype(np.int16)

        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
        feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
        ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
        feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)
        # feature_xcorr = feature_xform(feature_xcorr)

        os.remove('./temp.raw')
        os.remove('./temp_xcorr.f32')

        if data_format == 'if':
            feature = feature_if
        elif data_format == 'xcorr':
            feature = feature_xcorr
        else:
            # Joint format: truncate to the shorter track and concatenate.
            indmin = min(feature_if.shape[0],feature_xcorr.shape[0])
            feature = np.concatenate([feature_xcorr[:indmin,:],feature_if[:indmin,:]],-1)

        # Reference f0 file: column 0 = pitch (Hz), column 1 = voicing flag.
        pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
        pitch = np.loadtxt(pitch_file_name)[:,0]
        voicing = np.loadtxt(pitch_file_name)[:,1]
        indmin = min(voicing.shape[0],rmse.shape[0],pitch.shape[0])
        pitch = pitch[:indmin]
        voicing = voicing[:indmin]
        rmse = rmse[:indmin]
        # Additionally treat very low-energy frames as unvoiced.
        voicing = voicing*(rmse > 0.05*np.max(rmse))
        if "mic_F" in audio_file:
            # For female speakers, drop implausibly low reference pitches.
            idx_correct = np.where(pitch < 125)
            voicing[idx_correct] = 0

        # Reference pitch in (rounded) cents relative to 16000/256 Hz.
        cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')

        # Model outputs logits over 20-cent bins; argmax*20 -> cents.
        model_cents = model(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
        model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()

        num_frames = min(cent.shape[0],model_cents.shape[0])
        pitch = pitch[:num_frames]
        cent = cent[:num_frames]
        voicing = voicing[:num_frames]
        model_cents = model_cents[:num_frames]

        voicing_all = np.copy(voicing)
        # Forcefully make regions where pitch is <65 or greater than 500 unvoiced for relevant accurate pitch comparisons for our model
        force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 500)==True)
        voicing_all[force_out_of_pitch] = 0
        C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

        list_rca_model_all.append(rca(cent,model_cents,voicing_all,thresh))

        if "mic_M" in audio_file:
            list_rca_male_all.append(rca(cent,model_cents,voicing_all,thresh))
            C_all_m = C_all_m + np.where(voicing_all != 0)[0].shape[0]
        else:
            list_rca_female_all.append(rca(cent,model_cents,voicing_all,thresh))
            C_all_f = C_all_f + np.where(voicing_all != 0)[0].shape[0]

    list_rca_model_all = np.array(list_rca_model_all)
    list_rca_male_all = np.array(list_rca_male_all)
    list_rca_female_all = np.array(list_rca_female_all)

    # Summarize: hits / voiced frames, overall and per speaker sex.
    x = PrettyTable()

    x.field_names = ["Experiment", "Mean RPA"]
    x.add_row(["Both all pitches", np.sum(list_rca_model_all)/C_all])

    x.add_row(["Male all pitches", np.sum(list_rca_male_all)/C_all_m])

    x.add_row(["Female all pitches", np.sum(list_rca_female_all)/C_all_f])

    print(x)

    return None
|
||||
|
||||
def cycle_eval(checkpoint_list, noise_type = 'synthetic', noise_dataset = None, list_snr = None, ptdb_dataset_path = None,fraction = 0.1,thresh = 50):
    """
    Cycle through SNR evaluation for a list of checkpoints.

    For each entry in `checkpoint_list` — either a .pth model checkpoint or
    one of the baselines 'crepe' / 'lpcnet' — evaluates raw pitch accuracy on
    a random fraction of PTDB files at each SNR in `list_snr`, plus a final
    clean (infinite-SNR) condition.

    Returns a dict keyed by checkpoint/baseline name with:
      'list_SNR' : per-SNR accuracy (same order as list_snr)
      'inf'      : accuracy on clean audio

    Fixes vs. the previous version:
      * list_snr defaulted to a shared mutable list that was mutated with
        np.inf on every call (growing the caller's list); now None expands to
        the old default and the input list is copied before appending.
      * fname was only assigned in the crepe/lpcnet branch, so evaluating a
        model checkpoint raised NameError when storing results.
      * dead `from models import large_if_ccode ...` imports (names that do
        not exist in models.py) removed.
    """
    if list_snr is None:
        list_snr = [-20, -15, -10, -5, 0, 5, 10, 15, 20]
    # Work on a copy so the caller's list is never mutated; np.inf marks the
    # clean (no added noise) condition, evaluated last.
    snr_values = list(list_snr)
    snr_values.append(np.inf)

    list_files = glob.glob(ptdb_dataset_path + 'combined_mic_16k/*.raw')
    dir_f0 = ptdb_dataset_path + 'combined_reference_f0/'
    random.shuffle(list_files)
    list_files = list_files[:(int)(fraction*len(list_files))]

    dict_models = {}

    for f in checkpoint_list:
        # Result-dict key: checkpoint path for models, method name otherwise.
        fname = f
        if (f!='crepe') and (f!='lpcnet'):
            # ---- our model checkpoints ----
            checkpoint = torch.load(f, map_location='cpu')
            dict_params = checkpoint['config']
            if dict_params['data_format'] == 'if':
                pitch_nn = PitchDNNIF(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
            elif dict_params['data_format'] == 'xcorr':
                pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
            else:
                pitch_nn = PitchDNN(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])

            pitch_nn.load_state_dict(checkpoint['state_dict'])

            N = dict_params['window_size']
            H = dict_params['hop_factor']
            freq_keep = dict_params['freq_keep']

            list_mean = []
            list_std = []
            for snr_dB in snr_values:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    file_name = os.path.basename(list_files[idx])[:-4]

                    audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
                    # Skip the PTDB header region so audio aligns with the f0 reference.
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y = audio,frame_length = N,hop_length = H))

                    # Draw noise: a random DEMAND excerpt or filtered white noise.
                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16,mode = 'r')/(2**15 - 1)
                        rand_range = np.random.randint(low = 0, high = (16000*60*5 - audio.shape[0])) # Last 1 minute of noise used for testing
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                        n = random_filter(n)

                    # Scale noise to hit the target SNR (multiplier is 0 at inf).
                    snr_multiplier = np.sqrt((np.sum(np.abs(audio)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
                    audio = audio + snr_multiplier*n

                    # IF features: log-magnitude + normalized inter-frame phase advance.
                    spec = stft(x = np.concatenate([np.zeros(160),audio]), w = 'boxcar', N = N, H = H).T
                    phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
                    phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
                    idx_save = np.concatenate([np.arange(freq_keep),(N//2 + 1) + np.arange(freq_keep),2*(N//2 + 1) + np.arange(freq_keep)])
                    feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
                    feature_if = feature[:,idx_save]

                    # xcorr features from the external LPCNet extractor binary.
                    data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                    data_temp[:audio.shape[0]] = ((audio)*(2**15 - 1)).astype(np.int16)

                    subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
                    feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
                    ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
                    feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)

                    os.remove('./temp.raw')
                    os.remove('./temp_xcorr.f32')

                    if dict_params['data_format'] == 'if':
                        feature = feature_if
                    elif dict_params['data_format'] == 'xcorr':
                        feature = feature_xcorr
                    else:
                        indmin = min(feature_if.shape[0],feature_xcorr.shape[0])
                        feature = np.concatenate([feature_xcorr[:indmin,:],feature_if[:indmin,:]],-1)

                    # Reference f0: column 0 = pitch (Hz), column 1 = voicing flag.
                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch = np.loadtxt(pitch_file_name)[:,0]
                    voicing = np.loadtxt(pitch_file_name)[:,1]
                    indmin = min(voicing.shape[0],rmse.shape[0],pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing*(rmse > 0.05*np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0

                    cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')

                    # Logits over 20-cent bins -> predicted pitch in cents.
                    model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
                    model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()

                    num_frames = min(cent.shape[0],model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]

                    voicing_all = np.copy(voicing)
                    # Forcefully make regions where pitch is <65 or greater than 500 unvoiced for relevant accurate pitch comparisons for our model
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 500)==True)
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

                    C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh)
                list_mean.append(C_correct/C_all)
        else:
            # ---- baselines: CREPE or the LPCNet pitch extractor ----
            list_mean = []
            list_std = []
            for snr_dB in snr_values:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    file_name = os.path.basename(list_files[idx])[:-4]

                    audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y = audio,frame_length = 320,hop_length = 160))

                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16,mode = 'r')/(2**15 - 1)
                        rand_range = np.random.randint(low = 0, high = (16000*60*5 - audio.shape[0])) # Last 1 minute of noise used for testing
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                        n = random_filter(n)

                    snr_multiplier = np.sqrt((np.sum(np.abs(audio)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
                    audio = audio + snr_multiplier*n

                    if (f == 'crepe'):
                        # 80-sample pad keeps CREPE frames aligned with the reference.
                        _, model_frequency, _, _ = crepe.predict(np.concatenate([np.zeros(80),audio]), 16000, viterbi=True,center=True,verbose=0)
                        model_cents = 1200*np.log2(model_frequency/(16000/256) + 1.0e-8)
                    else:
                        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                        data_temp[:audio.shape[0]] = ((audio)*(2**15 - 1)).astype(np.int16)

                        # Third argument makes the extractor also dump its own
                        # per-frame period estimate, used as the LPCNet baseline.
                        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32', './temp_period.f32'])
                        feature_xcorr = np.fromfile('./temp_period.f32', dtype='float32')
                        model_cents = 1200*np.log2((256/feature_xcorr + 1.0e-8) + 1.0e-8)

                        os.remove('./temp.raw')
                        os.remove('./temp_xcorr.f32')
                        os.remove('./temp_period.f32')

                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch = np.loadtxt(pitch_file_name)[:,0]
                    voicing = np.loadtxt(pitch_file_name)[:,1]
                    indmin = min(voicing.shape[0],rmse.shape[0],pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing*(rmse > 0.05*np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0

                    cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')
                    num_frames = min(cent.shape[0],model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]

                    voicing_all = np.copy(voicing)
                    # Forcefully make regions where pitch is <65 or greater than 500 unvoiced for relevant accurate pitch comparisons for our model
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 500)==True)
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]

                    C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh)
                list_mean.append(C_correct/C_all)

        dict_models[fname] = {}
        dict_models[fname]['list_SNR'] = list_mean[:-1]
        dict_models[fname]['inf'] = list_mean[-1]

    return dict_models
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Running the experiments;
|
||||
1. RCA vs SNR for our models, CREPE, LPCNet
|
||||
"""
|
||||
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('ptdb_root', type=str, help='Root Directory for PTDB generated by running ptdb_process.sh ')
|
||||
parser.add_argument('output', type=str, help='Output dump file name')
|
||||
parser.add_argument('method', type=str, help='Output Directory to save experiment dumps',choices=['model','lpcnet','crepe'])
|
||||
parser.add_argument('--noise_dataset', type=str, help='Location of the Demand Datset',default = './',required=False)
|
||||
parser.add_argument('--noise_type', type=str, help='Type of additive noise',default = 'synthetic',choices=['synthetic','demand'],required=False)
|
||||
parser.add_argument('--pth_file', type=str, help='.pth file to analyze',default = './',required = False)
|
||||
parser.add_argument('--fraction_files_analyze', type=float, help='Fraction of PTDB dataset to test on',default = 1,required = False)
|
||||
parser.add_argument('--threshold_rca', type=float, help='Cent threshold when computing RCA',default = 50,required = False)
|
||||
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
import os
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)
|
||||
|
||||
import json
|
||||
from evaluation import cycle_eval
|
||||
|
||||
if args.method == 'model':
|
||||
dict_store = cycle_eval([args.pth_file], noise_type = args.noise_type, noise_dataset = args.noise_dataset, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = args.ptdb_root,fraction = args.fraction_files_analyze,thresh = args.threshold_rca)
|
||||
else:
|
||||
dict_store = cycle_eval([args.method], noise_type = args.noise_type, noise_dataset = args.noise_dataset, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = args.ptdb_root,fraction = args.fraction_files_analyze,thresh = args.threshold_rca)
|
||||
|
||||
dict_store["method"] = args.method
|
||||
if args.method == 'model':
|
||||
dict_store['pth'] = args.pth_file
|
||||
|
||||
with open(args.output, 'w') as fp:
|
||||
json.dump(dict_store, fp)
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
|
||||
|
||||
|
||||
# CLI: checkpoint to export and destination directory for the generated C code.
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')

args = parser.parse_args()
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from models import PitchDNN
|
||||
from wexchange.torch import dump_torch_weights
|
||||
from wexchange.c_export import CWriter, print_vector
|
||||
|
||||
def c_export(args, model):
    """Dump the PitchDNN weights as C source via the weight-exchange CWriter.

    Dense layers are exported quantized; the 2D conv layers unquantized; the
    GRU both quantized and with its unit count recorded for buffer sizing.
    Layer name pairs map PyTorch submodule paths to C symbol names.
    """

    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    writer = CWriter(os.path.join(args.output_dir, "pitchdnn_data"), message=message, model_struct_name='PitchDNN')
    writer.header.write(
f"""
#include "opus_types.h"
"""
    )

    # (torch submodule path, exported C name) for the dense layers.
    dense_layers = [
        ('if_upsample.0', "dense_if_upsampler_1"),
        ('if_upsample.2', "dense_if_upsampler_2"),
        ('downsample.0', "dense_downsampler"),
        ("upsample.0", "dense_final_upsampler")
    ]

    for name, export_name in dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=True, scale=None)

    # 2D convolution layers (exported without quantization).
    conv_layers = [
        ('conv.1', "conv2d_1"),
        ('conv.4', "conv2d_2")
    ]

    for name, export_name in conv_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True)

    gru_layers = [
        ("GRU", "gru_1"),
    ]

    # dump_torch_weights returns the unit count for recurrent layers; the max
    # sizes the scratch buffer in the C inference code.
    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
                         for name, export_name in gru_layers])

    writer.header.write(
f"""

#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}

"""
    )

    writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model = PitchDNN()
|
||||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
c_export(args, model)
|
||||
178
managed_components/78__esp-opus/dnn/torch/neural-pitch/models.py
Normal file
178
managed_components/78__esp-opus/dnn/torch/neural-pitch/models.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Pitch Estimation Models and dataloaders
|
||||
- Classification Based (Input features, output logits)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class PitchDNNIF(torch.nn.Module):
    """Pitch classifier over instantaneous-frequency features.

    Two tanh-activated dense layers, a GRU over time, and a final dense
    projection to `output_dim` pitch-bin logits. Output is permuted to
    (batch, bins, time) for cross-entropy-style losses.
    """

    def __init__(self, input_dim=88, gru_dim=64, output_dim=192):
        super().__init__()

        self.activation = torch.nn.Tanh()
        self.initial = torch.nn.Linear(input_dim, gru_dim)
        self.hidden = torch.nn.Linear(gru_dim, gru_dim)
        self.gru = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, batch_first=True)
        self.upsample = torch.nn.Linear(gru_dim, output_dim)

    def forward(self, x):
        # Dense front-end: (batch, time, input_dim) -> (batch, time, gru_dim).
        h = self.activation(self.initial(x))
        h = self.activation(self.hidden(h))
        # Recurrent stage over the time axis (batch_first).
        h, _ = self.gru(h)
        h = self.activation(self.upsample(h))
        # (batch, time, bins) -> (batch, bins, time).
        return h.permute(0, 2, 1)
|
||||
|
||||
class PitchDNNXcorr(torch.nn.Module):
    """Pitch classifier over cross-correlation features.

    A small 2D conv stack smooths the (lag, time) correlation image, a dense
    layer projects each frame down to the GRU width, and a dense head emits
    `output_dim` pitch-bin logits per frame, returned as (batch, bins, time).
    """

    def __init__(self, input_dim=90, gru_dim=64, output_dim=192):
        super().__init__()

        self.activation = torch.nn.Tanh()

        # Asymmetric padding (2 left on time, 1 top/bottom on lag) keeps the
        # 3x3 convs causal in time while preserving both spatial sizes.
        self.conv = torch.nn.Sequential(
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(1, 8, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(8, 8, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(8, 1, 3, bias=True),
            self.activation,
        )

        self.downsample = torch.nn.Sequential(
            torch.nn.Linear(input_dim, gru_dim),
            self.activation
        )
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        self.upsample = torch.nn.Sequential(
            torch.nn.Linear(gru_dim, output_dim),
            self.activation
        )

    def forward(self, x):
        # (batch, time, lags) -> (batch, 1, lags, time) image for the convs.
        z = x.unsqueeze(-1).permute(0, 3, 2, 1)
        z = self.conv(z).squeeze(1)
        # Back to (batch, time, lags), project per frame, run the GRU.
        z = self.downsample(z.permute(0, 2, 1))
        z, _ = self.GRU(z)
        # (batch, time, bins) -> (batch, bins, time).
        return self.upsample(z).permute(0, 2, 1)
|
||||
|
||||
class PitchDNN(torch.nn.Module):
    """
    Joint IF-xcorr

    1D CNN on IF, merge with xcorr, 2D CNN on merged + GRU
    """

    def __init__(self, input_IF_dim=88, input_xcorr_dim=224, gru_dim=64, output_dim=192):
        super().__init__()

        self.activation = torch.nn.Tanh()

        # Dense branch for the IF features -> fixed 64-dim embedding.
        self.if_upsample = torch.nn.Sequential(
            torch.nn.Linear(input_IF_dim, 64),
            self.activation,
            torch.nn.Linear(64, 64),
            self.activation,
        )

        # Conv branch smooths the (lag, time) xcorr image; time-causal pads.
        self.conv = torch.nn.Sequential(
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(1, 4, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(4, 1, 3, bias=True),
            self.activation,
        )

        self.downsample = torch.nn.Sequential(
            torch.nn.Linear(64 + input_xcorr_dim, gru_dim),
            self.activation
        )
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        # Final head emits raw logits (no activation).
        self.upsample = torch.nn.Sequential(
            torch.nn.Linear(gru_dim, output_dim)
        )

    def forward(self, x):
        # The first 224 features of each frame are the xcorr lags; the rest
        # are the IF features.
        xcorr_in = x[:, :, :224]
        if_in = x[:, :, 224:]
        xcorr_out = self.conv(xcorr_in.unsqueeze(-1).permute(0, 3, 2, 1)).squeeze(1).permute(0, 2, 1)
        if_out = self.if_upsample(if_in)
        merged = torch.cat([xcorr_out, if_out], dim=-1)
        h, _ = self.GRU(self.downsample(merged))
        # (batch, time, bins) -> (batch, bins, time).
        return self.upsample(h).permute(0, 2, 1)
|
||||
|
||||
|
||||
# Dataloaders
|
||||
class Loader(torch.utils.data.Dataset):
    """Dataset of IF features with pitch/confidence targets from a .npy file.

    Reads a float32 feature memmap and a pitch file whose row 0 holds pitch
    values and row 1 confidences, then chops all streams into fixed-length
    context windows.
    """

    def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100):
        # Flat (frames, 3 * dimension_if) feature matrix, memory-mapped read-only-ish.
        self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1,3*dimension_if)

        # Resolution of 20 cents
        # NOTE(review): assumes row 0 is already expressed in cents so that
        # /20 yields one of 180 classes -- confirm against the file generator.
        self.cents = np.rint(np.load(file_pitch)[0,:]/20)
        self.cents = np.clip(self.cents,0,179)
        self.confidence = np.load(file_pitch)[1,:]

        # Filter confidence for CREPE
        # Frames below the threshold get zero weight downstream.
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context
        # Clip both to same size
        size_common = min(self.if_feat.shape[0], self.cents.shape[0])
        self.if_feat = self.if_feat[:size_common,:]
        self.cents = self.cents[:size_common]
        self.confidence = self.confidence[:size_common]

        # Reshape every stream into (num_windows, context, ...) blocks,
        # dropping any trailing partial window.
        frame_max = self.if_feat.shape[0]//context
        self.if_feat = np.reshape(self.if_feat[:frame_max*context, :],(frame_max, context,3*dimension_if))
        self.cents = np.reshape(self.cents[:frame_max * context],(frame_max, context))
        self.confidence = np.reshape(self.confidence[:frame_max*context],(frame_max, context))

    def __len__(self):
        # Number of context windows.
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        # (features, per-frame target class, per-frame confidence weight)
        return torch.from_numpy(self.if_feat[index,:,:]), torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])
class PitchDNNDataloader(torch.utils.data.Dataset):
    """Dataset over LPCNet int8 features (224 xcorr + 88 IF bins per frame).

    Targets come from a float32 ground-truth memmap of (pitch, confidence)
    pairs; pitch is converted to 20-cent classes relative to 62.5 Hz.
    ``choice_data`` selects which feature slice ``__getitem__`` yields.
    """

    def __init__(self, features, file_pitch, confidence_threshold=0.4, context=100, choice_data='both'):
        # Each 312-entry frame: 224 xcorr bins followed by 88 IF bins (int8).
        self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1,312)
        self.xcorr = self.feat[:,:224]
        self.if_feat = self.feat[:,224:]
        ground_truth = np.memmap(file_pitch, mode='r', dtype=np.float32).reshape(-1,2)
        # 60 classes per octave (20-cent steps) relative to 62.5 Hz.
        self.cents = np.rint(60*np.log2(ground_truth[:,0]/62.5))
        # Zero the confidence of frames whose pitch falls outside the class range.
        # NOTE(review): the mask admits cents == 180 but the clip below caps at
        # 179, so boundary frames map onto class 179 -- confirm intended.
        mask = (self.cents>=0).astype('float32') * (self.cents<=180).astype('float32')
        self.cents = np.clip(self.cents,0,179)
        self.confidence = ground_truth[:,1] * mask
        # Filter confidence for CREPE
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context

        self.choice_data = choice_data

        # Chop every stream into (num_windows, context, ...) blocks, dropping
        # the trailing partial window.
        frame_max = self.if_feat.shape[0]//context
        self.if_feat = np.reshape(self.if_feat[:frame_max*context,:], (frame_max, context, 88))
        self.cents = np.reshape(self.cents[:frame_max*context], (frame_max,context))
        self.xcorr = np.reshape(self.xcorr[:frame_max*context,:], (frame_max,context, 224))
        self.confidence = np.reshape(self.confidence[:frame_max*context], (frame_max, context))

    def __len__(self):
        # Number of context windows.
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        # int8 features are scaled by 1/127 into roughly [-1, 1].
        if self.choice_data == 'both':
            return torch.cat([torch.from_numpy((1./127)*self.xcorr[index,:,:]), torch.from_numpy((1./127)*self.if_feat[index,:,:])], dim=-1), torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])
        elif self.choice_data == 'if':
            return torch.from_numpy((1./127)*self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
        else:
            return torch.from_numpy((1./127)*self.xcorr[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
@@ -0,0 +1,179 @@
|
||||
# Command-line front end: replace the LPCNet pitch/xcorr entries of a feature
# file with estimates from a trained neural pitch model.
import argparse
parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='Features generated from dump_data')
parser.add_argument('data', type=str, help='Data generated from dump_data (offset by 5ms)')
parser.add_argument('output', type=str, help='output .f32 feature file with replaced neural pitch')
parser.add_argument('checkpoint', type=str, help='model checkpoint file')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
parser.add_argument('--device', type=str, help='compute device', default=None, required=False)
parser.add_argument('--replace_xcorr', type=bool, default=False, help='Replace LPCNet xcorr with updated one')

args = parser.parse_args()

import os

from utils import stft, random_filter
import subprocess
import numpy as np
import json
import torch
import tqdm

from models import PitchDNNIF, PitchDNNXcorr, PitchDNN

# Bug fix: check the --device OPTION, not the freshly-built device object.
# The old code tested `if device is not None`, which is always true, so it
# always called torch.device(args.device) and crashed with
# torch.device(None) whenever --device was omitted.
if args.device is not None:
    device = torch.device(args.device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the appropriate model
checkpoint = torch.load(args.checkpoint, map_location='cpu')
dict_params = checkpoint['config']

# NOTE(review): the training script sizes the IF input as 3*freq_keep - 2,
# while this uses freq_keep*3 -- confirm which matches the saved checkpoints
# (a mismatch would fail in load_state_dict below).
if dict_params['data_format'] == 'if':
    pitch_nn = PitchDNNIF(dict_params['freq_keep']*3, dict_params['gru_dim'], dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
    pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
else:
    pitch_nn = PitchDNN(dict_params['freq_keep']*3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])

pitch_nn.load_state_dict(checkpoint['state_dict'])
pitch_nn = pitch_nn.to(device)

# STFT analysis parameters stored with the checkpoint.
N = dict_params['window_size']
H = dict_params['hop_factor']
freq_keep = dict_params['freq_keep']

os.environ["OMP_NUM_THREADS"] = "16"
|
||||
def run_lpc(signal, lpcs, frame_length=160):
    """Run short-term LPC prediction frame by frame.

    Args:
        signal: 1D signal; must contain num_frames * frame_length + lpc_order
            samples so each frame plus its filter history is available.
        lpcs: (num_frames, lpc_order) prediction coefficients per frame.
        frame_length: samples per frame.

    Returns:
        (prediction, error): the negated LPC prediction and the residual,
        both aligned to signal[lpc_order:].
    """
    num_frames, lpc_order = lpcs.shape

    frame_preds = []
    for n in range(num_frames):
        # Include lpc_order - 1 history samples so 'valid' convolution
        # yields exactly frame_length outputs.
        segment = signal[n * frame_length : (n + 1) * frame_length + lpc_order - 1]
        frame_preds.append(-np.convolve(segment, lpcs[n], mode='valid'))
    prediction = np.concatenate(frame_preds)

    error = signal[lpc_order:] - prediction

    return prediction, error
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    # dump_data output: 36 float32 features per frame; data holds int16
    # (prediction, signal) pairs.
    features = np.memmap(args.features, dtype=np.float32,mode = 'r').reshape((-1, 36))
    data = np.memmap(args.data, dtype=np.int16,mode = 'r').reshape((-1, 2))

    num_frames = features.shape[0]
    feature_dim = features.shape[1]

    assert feature_dim == 36

    # Start from a verbatim copy of the input features; only the pitch and
    # xcorr columns are overwritten below.
    output = np.memmap(args.output, dtype=np.float32, shape=(num_frames, feature_dim), mode='w+')
    output[:, :36] = features

    # lpc coefficients and signal
    lpcs = features[:, 20:36]
    sig = data[:, 1]

    # parameters

    # constants
    pitch_min = 32
    pitch_max = 256
    lpc_order = 16
    fs = 16000
    frame_length = 160
    overlap_frames = 100
    chunk_size = 10000
    history_length = frame_length * overlap_frames
    history = np.zeros(history_length, dtype=np.int16)
    # Column indices in the 36-wide feature rows.
    pitch_position=18
    xcorr_position=19
    # NOTE(review): conf_position is defined but never used in this loop.
    conf_position=36

    # Re-derive the frame count from the signal length (one frame of slack).
    num_frames = len(sig) // 160 - 1

    frame_start = 0
    frame_stop = min(frame_start + chunk_size, num_frames)
    signal_start = 0
    signal_stop = frame_stop * frame_length

    niters = (num_frames - 1)//chunk_size
    # Process the signal in chunks of chunk_size frames, carrying
    # overlap_frames of audio history so features and LPC filtering have
    # context across chunk boundaries.
    for i in tqdm.trange(niters):
        if (frame_start > num_frames - 1):
            break
        # chunk_la additionally carries 80 samples of lookahead.
        chunk = np.concatenate((history, sig[signal_start:signal_stop]))
        chunk_la = np.concatenate((history, sig[signal_start:signal_stop + 80]))

        # Feature computation
        # IF features: log magnitude plus normalized inter-bin phase
        # differences of a boxcar STFT, keeping the lowest freq_keep bins
        # of each of the three groups.
        spec = stft(x = np.concatenate([np.zeros(80),chunk_la/(2**15 - 1)]), w = 'boxcar', N = N, H = H).T
        phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
        phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
        idx_save = np.concatenate([np.arange(freq_keep),(N//2 + 1) + np.arange(freq_keep),2*(N//2 + 1) + np.arange(freq_keep)])
        feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
        feature_if = feature[:,idx_save]

        # Write the chunk to a temp file and run the external LPCNet
        # extractor to obtain 256-lag xcorr features.
        data_temp = np.memmap('./temp_featcompute_' + dict_params['data_format'] + '_.raw', dtype=np.int16, shape=(chunk.shape[0]), mode='w+')
        data_temp[:chunk.shape[0]] = chunk_la[80:].astype(np.int16)

        subprocess.run([args.path_lpcnet_extractor, './temp_featcompute_' + dict_params['data_format'] + '_.raw', './temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw'])
        # Flip lags and prepend a unit zero-lag column -> 257-wide xcorr.
        feature_xcorr = np.flip(np.fromfile('./temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
        ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
        feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)

        os.remove('./temp_featcompute_' + dict_params['data_format'] + '_.raw')
        os.remove('./temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw')

        # Select the feature layout the checkpoint was trained on.
        if dict_params['data_format'] == 'if':
            feature = feature_if
        elif dict_params['data_format'] == 'xcorr':
            feature = feature_xcorr
        else:
            indmin = min(feature_if.shape[0],feature_xcorr.shape[0])
            feature = np.concatenate([feature_xcorr[:indmin,:],feature_if[:indmin,:]],-1)

        # Compute pitch with my model
        # argmax over the class dim, 20 cents per class, relative to 62.5 Hz.
        model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
        model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
        frequency = 62.5*2**(model_cents/1200)

        # Drop the history prefix; keep only this chunk's frames.
        frequency = frequency[overlap_frames : overlap_frames + frame_stop - frame_start]

        # convert frequencies to periods
        periods = np.round(fs / frequency)

        periods = np.clip(periods, pitch_min, pitch_max)

        # Feature-file convention for the pitch column.
        output[frame_start:frame_stop, pitch_position] = (periods - 100) / 50

        # Enough leading frames so a full pitch_max lag is addressable.
        frame_offset = (pitch_max + frame_length - 1) // frame_length
        offset = frame_offset * frame_length
        padding = lpc_order

        # Zero-pad LPC coefficients at the very start of the file where no
        # earlier frames exist.
        if frame_start < frame_offset:
            lpc_coeffs = np.concatenate((np.zeros((frame_offset - frame_start, lpc_order), dtype=np.float32), lpcs[:frame_stop]))
        else:
            lpc_coeffs = lpcs[frame_start - frame_offset : frame_stop]

        pred, error = run_lpc(chunk[history_length - offset - padding :], lpc_coeffs, frame_length=frame_length)

        # Normalized autocorrelation of the LPC residual at the estimated
        # pitch lag, one value per frame.
        # NOTE(review): this loop's `i` shadows the outer chunk index `i`;
        # harmless because the outer value is unused afterwards, but worth
        # renaming in a future edit.
        xcorr = np.zeros(frame_stop - frame_start)
        for i, p in enumerate(periods.astype(np.int16)):
            if p > 0:
                f1 = error[offset + i * frame_length : offset + (i + 1) * frame_length]
                f2 = error[offset + i * frame_length - p : offset + (i + 1) * frame_length - p]
                xcorr[i] = np.dot(f1, f2) / np.sqrt(np.dot(f1, f1) * np.dot(f2, f2) + 1e-6)

        # Feature-file convention for the xcorr column.
        output[frame_start:frame_stop, xcorr_position] = xcorr - 0.5

        # update buffers and indices
        history = chunk[-history_length :]

        frame_start += chunk_size
        frame_stop += chunk_size
        frame_stop = min(frame_stop, num_frames)

        signal_start = frame_start * frame_length
        signal_stop = frame_stop * frame_length
||||
@@ -0,0 +1,34 @@
|
||||
# Copy into PTDB root directory and run to combine all the male/female raw audio/references into below directories

# The MALE/FEMALE trees nest per-speaker directories, so enable recursive
# globbing; without globstar, `**` silently behaves like a single `*`.
shopt -s globstar

# Make folder for combined audio
mkdir -p './combined_mic_16k/'
# Make folder for combined pitch reference
mkdir -p './combined_reference_f0/'

# Resample Male Audio
for i in ./MALE/MIC/**/*.wav; do
    j="$(basename "$i" .wav)"
    echo "$j"
    sox -r 48000 -b 16 -e signed-integer "$i" -r 16000 -b 16 -e signed-integer "./combined_mic_16k/$j.raw"
done

# Resample Female Audio
for i in ./FEMALE/MIC/**/*.wav; do
    j="$(basename "$i" .wav)"
    echo "$j"
    sox -r 48000 -b 16 -e signed-integer "$i" -r 16000 -b 16 -e signed-integer "./combined_mic_16k/$j.raw"
done

# Shift Male reference pitch files
for i in ./MALE/REF/**/*.f0; do
    # Bug fix: these are .f0 files, so strip the .f0 suffix (was .wav).
    j="$(basename "$i" .f0)"
    echo "$j"
    cp "$i" ./combined_reference_f0/
done

# Shift Female reference pitch files
for i in ./FEMALE/REF/**/*.f0; do
    j="$(basename "$i" .f0)"
    echo "$j"
    cp "$i" ./combined_reference_f0/
done
||||
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Perform Data Augmentation (Gain, Additive Noise, Random Filtering) on Input TTS Data
|
||||
1. Read in chunks and compute clean pitch first
|
||||
2. Then add in augmentation (Noise/Level/Response)
|
||||
- Adds filtered noise from the "Demand" dataset, https://zenodo.org/record/1227121#.XRKKxYhKiUk
|
||||
- When using the Demand Dataset, consider each channel as a possible noise input, and keep the first 4 minutes of noise for training
|
||||
3. Use this "augmented" audio for feature computation, and compute pitch using CREPE on the clean input
|
||||
|
||||
Notes: To ensure consistency with the discovered CREPE offset, we do the following
|
||||
- We pad the input audio to the zero-centered CREPE estimator with 80 zeros
|
||||
- We pad the input audio to our feature computation with 160 zeros to center them
|
||||
"""
|
||||
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('data', type=str, help='input raw audio data')
|
||||
parser.add_argument('output', type=str, help='output directory')
|
||||
parser.add_argument('--gpu-index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
|
||||
parser.add_argument('--chunk-size-frames', type=int, help='Number of frames to process at a time',default = 100000,required = False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import crepe
|
||||
|
||||
data = np.memmap(args.data, dtype=np.int16,mode = 'r')
|
||||
|
||||
# list_features = []
|
||||
list_cents = []
|
||||
list_confidences = []
|
||||
|
||||
min_period = 32
|
||||
max_period = 256
|
||||
f_ref = 16000/max_period
|
||||
chunk_size_frames = args.chunk_size_frames
|
||||
chunk_size = chunk_size_frames*160
|
||||
|
||||
nb_chunks = (data.shape[0]+79)//chunk_size+1
|
||||
|
||||
output_data = np.zeros((0,2),dtype='float32')
|
||||
|
||||
for i in tqdm.trange(nb_chunks):
|
||||
if i==0:
|
||||
chunk = np.concatenate([np.zeros(80),data[:chunk_size-80]])
|
||||
elif i==nb_chunks-1:
|
||||
chunk = data[i*chunk_size-80:]
|
||||
else:
|
||||
chunk = data[i*chunk_size-80:(i+1)*chunk_size-80]
|
||||
chunk = chunk/np.array(32767.,dtype='float32')
|
||||
|
||||
# Clean Pitch/Confidence Estimate
|
||||
# Padding input to CREPE by 80 samples to ensure it aligns
|
||||
_, pitch, confidence, _ = crepe.predict(chunk, 16000, center=True, viterbi=True,verbose=0)
|
||||
pitch = pitch[:chunk_size_frames]
|
||||
confidence = confidence[:chunk_size_frames]
|
||||
|
||||
|
||||
# Filter out of range pitches/confidences
|
||||
confidence[pitch < 16000/max_period] = 0
|
||||
confidence[pitch > 16000/min_period] = 0
|
||||
pitch = np.reshape(pitch, (-1, 1))
|
||||
confidence = np.reshape(confidence, (-1, 1))
|
||||
out = np.concatenate([pitch, confidence], axis=-1, dtype='float32')
|
||||
output_data = np.concatenate([output_data, out], axis=0)
|
||||
|
||||
|
||||
output_data.tofile(args.output)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Training the neural pitch estimator
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
|
||||
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
|
||||
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
|
||||
parser.add_argument('data_format', type=str, help='Choice of Input Data',choices=['if','xcorr','both'])
|
||||
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
|
||||
parser.add_argument('--confidence_threshold', type=float, help='Confidence value below which pitch will be neglected during training',default = 0.4,required = False)
|
||||
parser.add_argument('--context', type=int, help='Sequence length during training',default = 100,required = False)
|
||||
parser.add_argument('--N', type=int, help='STFT window size',default = 320,required = False)
|
||||
parser.add_argument('--H', type=int, help='STFT Hop size',default = 160,required = False)
|
||||
parser.add_argument('--xcorr_dimension', type=int, help='Dimension of Input cross-correlation',default = 257,required = False)
|
||||
parser.add_argument('--freq_keep', type=int, help='Number of Frequencies to keep',default = 30,required = False)
|
||||
parser.add_argument('--gru_dim', type=int, help='GRU Dimension',default = 64,required = False)
|
||||
parser.add_argument('--output_dim', type=int, help='Output dimension',default = 192,required = False)
|
||||
parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
|
||||
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
|
||||
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
|
||||
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# import os
|
||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)
|
||||
|
||||
# Fixing the seeds for reproducability
|
||||
import time
|
||||
np_seed = int(time.time())
|
||||
torch_seed = int(time.time())
|
||||
|
||||
import torch
|
||||
torch.manual_seed(torch_seed)
|
||||
import numpy as np
|
||||
np.random.seed(np_seed)
|
||||
from utils import count_parameters
|
||||
import tqdm
|
||||
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
if args.data_format == 'if':
|
||||
pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
|
||||
elif args.data_format == 'xcorr':
|
||||
pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
|
||||
else:
|
||||
pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
|
||||
|
||||
if type(args.initial_checkpoint) != type(None):
|
||||
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
|
||||
|
||||
dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
|
||||
|
||||
def loss_custom(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted cross entropy over pitch classes.

    logits: (batch, nmax, time) raw scores; labels: (batch, time) class ids;
    confidence: (batch, time) per-frame weights. choice='default' yields
    categorical CE (mean-reduced); any other value yields the generalized
    "robust" CE with exponent q (sum-reduced).
    """
    probs = torch.nn.Softmax(dim=1)(logits).permute(0, 2, 1)   # (B, T, nmax)
    target_mask = torch.nn.functional.one_hot(labels.long(), nmax)

    if choice == 'default':
        # Categorical Cross Entropy on the target class only.
        per_frame = -torch.sum(torch.log(probs * target_mask + 1.0e-6) * target_mask, dim=-1)
        return torch.mean(confidence * per_frame)

    # Robust Cross Entropy; note the sum (not mean) reduction.
    per_frame = (1.0 / q) * (1 - torch.sum(torch.pow(probs * target_mask + 1.0e-7, q), dim=-1))
    return torch.sum(confidence * per_frame)
||||
def accuracy(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted frame accuracy: 1 - weighted error rate.

    choice/nmax/q are accepted only for signature parity with loss_custom;
    the result depends solely on the argmax over dim 1 of logits.
    """
    probs = torch.nn.Softmax(dim=1)(logits).permute(0, 2, 1)
    predicted = torch.argmax(probs, 2)
    wrong = (predicted != labels.long()) * 1.
    return 1. - torch.mean(confidence * wrong)
|
||||
# 95/5 train/validation split, seeded for reproducibility.
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05], generator=torch.Generator().manual_seed(torch_seed))

batch_size = 256
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)

pitch_nn = pitch_nn.to(device)
num_params = count_parameters(pitch_nn)
learning_rate = args.learning_rate
model_opt = torch.optim.Adam(pitch_nn.parameters(), lr = learning_rate)

num_epochs = args.epochs

for epoch in range(num_epochs):
    losses = []
    accs = []
    pitch_nn.train()
    with tqdm.tqdm(train_dataloader) as train_epoch:
        for i, (xi, yi, ci) in enumerate(train_epoch):
            # (features, target classes, confidence weights) per batch.
            yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
            pi = pitch_nn(xi.float())
            loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
            acc = accuracy(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
            acc = acc.detach()

            model_opt.zero_grad()
            loss.backward()
            model_opt.step()

            # Running means over the epoch for the progress bar.
            losses.append(loss.item())
            accs.append(acc.item())
            avg_loss = np.mean(losses)
            avg_acc = np.mean(accs)
            train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss, "acc" : avg_acc.item()})

    # Validation pass every 5 epochs (loss only, no gradient step).
    if epoch % 5 == 0:
        pitch_nn.eval()
        losses = []
        with tqdm.tqdm(test_dataloader) as test_epoch:
            for i, (xi, yi, ci) in enumerate(test_epoch):
                yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
                pi = pitch_nn(xi.float())
                loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
                losses.append(loss.item())
                avg_loss = np.mean(losses)
                test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})

pitch_nn.eval()

# Training/configuration metadata saved alongside the weights; the
# inference script reads these keys (window_size, hop_factor, freq_keep,
# data_format, ...) to rebuild the model.
config = dict(
    data_format=args.data_format,
    epochs=num_epochs,
    window_size= args.N,
    hop_factor= args.H,
    freq_keep=args.freq_keep,
    batch_size=batch_size,
    learning_rate=learning_rate,
    confidence_threshold=args.confidence_threshold,
    model_parameters=num_params,
    np_seed=np_seed,
    torch_seed=torch_seed,
    xcorr_dim=args.xcorr_dimension,
    dim_input=3*args.freq_keep - 2,
    gru_dim=args.gru_dim,
    output_dim=args.output_dim,
    choice_cel=args.choice_cel,
    context=args.context,
)

model_save_path = os.path.join(args.output_folder, f"{args.prefix}_{args.data_format}.pth")
checkpoint = {
    'state_dict': pitch_nn.state_dict(),
    'config': config
}
torch.save(checkpoint, model_save_path)
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Utility functions that are commonly used
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from scipy.signal import windows, lfilter
|
||||
from prettytable import PrettyTable
|
||||
|
||||
|
||||
# Source: https://gist.github.com/thongonary/026210fc186eb5056f2b6f1ca362d912
|
||||
def count_parameters(model):
    """Print a per-module table of trainable parameters and return the total count."""
    summary = PrettyTable(["Modules", "Parameters"])
    total = 0
    for name, tensor in model.named_parameters():
        # Frozen parameters are excluded from the count.
        if not tensor.requires_grad:
            continue
        n = tensor.numel()
        summary.add_row([name, n])
        total += n
    print(summary)
    print(f"Total Trainable Params: {total}")
    return total
|
||||
def stft(x, w='boxcar', N=320, H=160):
    """Short-time Fourier transform with window w, frame size N, hop H.

    The input is zero-padded by N samples at the end before framing.
    Returns an array of shape (num_frames, N // 2 + 1).
    """
    padded = np.concatenate([x, np.zeros(N)])
    win = windows.get_window(w, N)
    frames = [np.fft.rfft(padded[start:start + N] * win)
              for start in np.arange(0, padded.shape[0] - N, H)]
    return np.stack(frames)
|
||||
def random_filter(x):
|
||||
# Randomly filter x with second order IIR filter with coefficients in between -3/8,3/8
|
||||
filter_coeff = np.random.uniform(low = -3.0/8, high = 3.0/8, size = 4)
|
||||
b = [1,filter_coeff[0],filter_coeff[1]]
|
||||
a = [1,filter_coeff[2],filter_coeff[3]]
|
||||
return lfilter(b,a,x)
|
||||
|
||||
def feature_xform(feature):
|
||||
"""
|
||||
Take as input the (N * 256) xcorr features output by LPCNet and perform the following
|
||||
1. Downsample and Upsample by 2 (followed by smoothing)
|
||||
2. Append positional embeddings (of dim k) coresponding to each xcorr lag
|
||||
"""
|
||||
|
||||
from scipy.signal import resample_poly, lfilter
|
||||
|
||||
|
||||
feature_US = lfilter([0.25,0.5,0.25],[1],resample_poly(feature,2,1,axis = 1),axis = 1)[:,:feature.shape[1]]
|
||||
feature_DS = lfilter([0.5,0.5],[1],resample_poly(feature,1,2,axis = 1),axis = 1)
|
||||
Z_append = np.zeros((feature.shape[0],feature.shape[1] - feature_DS.shape[1]))
|
||||
feature_DS = np.concatenate([feature_DS,Z_append],axis = -1)
|
||||
|
||||
# pos_embedding = []
|
||||
# for i in range(k):
|
||||
# pos_embedding.append(np.cos((2**i)*np.pi*((np.repeat(np.arange(feature.shape[1]).reshape(feature.shape[1],1),feature.shape[0],axis = 1)).T/(2*feature.shape[1]))))
|
||||
|
||||
# pos_embedding = np.stack(pos_embedding,axis = -1)
|
||||
|
||||
feature = np.stack((feature_DS,feature,feature_US),axis = -1)
|
||||
# feature = np.concatenate((feature,pos_embedding),axis = -1)
|
||||
|
||||
return feature
|
||||
66
managed_components/78__esp-opus/dnn/torch/osce/README.md
Normal file
66
managed_components/78__esp-opus/dnn/torch/osce/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Opus Speech Coding Enhancement
|
||||
|
||||
This folder hosts models for enhancing Opus SILK.
|
||||
|
||||
## Environment setup
|
||||
The code is tested with python 3.11. Conda setup is done via
|
||||
|
||||
|
||||
`conda create -n osce python=3.11`
|
||||
|
||||
`conda activate osce`
|
||||
|
||||
`python -m pip install -r requirements.txt`
|
||||
|
||||
|
||||
## Generating training data
|
||||
First step is to convert all training items to 16 kHz and 16 bit pcm and then concatenate them. A convenient way to do this is to create a file list and then run
|
||||
|
||||
`python scripts/concatenator.py filelist 16000 dataset/clean.s16 --db_min -40 --db_max 0`
|
||||
|
||||
which additionally applies random level scaling. Data is taken from the datasets listed in dnn/datasets.txt, and the exact list of items used for training and validation is
|
||||
located in dnn/torch/osce/resources.
|
||||
|
||||
Second step is to run a patched version of opus_demo in the dataset folder, which will produce the coded output and add feature files. To build the patched opus_demo binary, check out the exp-neural-silk-enhancement branch and build opus_demo the usual way. Then run
|
||||
|
||||
`cd dataset && <path_to_patched_opus_demo>/opus_demo voip 16000 1 9000 -silk_random_switching 249 clean.s16 coded.s16 `
|
||||
|
||||
The argument to -silk_random_switching specifies the number of frames after which parameters are switched randomly.
|
||||
|
||||
## Regression loss based training
|
||||
Create a default setup for LACE or NoLACE via
|
||||
|
||||
`python make_default_setup.py model.yml --model lace/nolace --path2dataset <path2dataset>`
|
||||
|
||||
Then run
|
||||
|
||||
`python train_model.py model.yml <output folder> --no-redirect`
|
||||
|
||||
for running the training script in foreground or
|
||||
|
||||
`nohup python train_model.py model.yml <output folder> &`
|
||||
|
||||
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
|
||||
|
||||
## Adversarial training (NoLACE only)
|
||||
Create a default setup for NoLACE via
|
||||
|
||||
`python make_default_setup.py nolace_adv.yml --model nolace --adversarial --path2dataset <path2dataset>`
|
||||
|
||||
Then run
|
||||
|
||||
`python adv_train_model.py nolace_adv.yml <output folder> --no-redirect`
|
||||
|
||||
for running the training script in foreground or
|
||||
|
||||
`nohup python adv_train_model.py nolace_adv.yml <output folder> &`
|
||||
|
||||
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
|
||||
|
||||
## Inference
|
||||
Generating inference data is analogous to generating training data. Given an item 'item1.wav' run
|
||||
`mkdir item1.se && sox item1.wav -r 16000 -e signed-integer -b 16 item1.raw && cd item1.se && <path_to_patched_opus_demo>/opus_demo voip 16000 1 <bitrate> ../item1.raw noisy.s16`
|
||||
|
||||
The folder item1.se then serves as input for the test_model.py script or for the --testdata argument of train_model.py resp. adv_train_model.py
|
||||
|
||||
autogen.sh downloads pre-trained model weights to the subfolder dnn/models of the main repo.
|
||||
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import math as m
|
||||
import random
|
||||
|
||||
import yaml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
import pesq
|
||||
|
||||
from data import SilkEnhancementSet
|
||||
from models import model_dict
|
||||
|
||||
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder; ask for confirmation before writing into an existing one
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires a status argument and would raise a
        # TypeError here; sys.exit gives the intended clean shutdown instead
        sys.exit(0)

# exist_ok makes this a no-op when the user confirmed reuse of an existing folder
os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
|
||||
ref = None
|
||||
if args.testdata is not None:
|
||||
|
||||
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
|
||||
|
||||
inference_test = True
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
|
||||
|
||||
try:
|
||||
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
inference_test = False
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
lr_gen = lr * setup['training']['gen_lr_reduction']
|
||||
lambda_feat = setup['training']['lambda_feat']
|
||||
lambda_reg = setup['training']['lambda_reg']
|
||||
adv_target = setup['training'].get('adv_target', 'target')
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = SilkEnhancementSet(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
# create discriminator
|
||||
disc_name = setup['discriminator']['name']
|
||||
disc = model_dict[disc_name](
|
||||
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
|
||||
)
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
|
||||
|
||||
# disc optimizer
|
||||
parameters = [p for p in disc.parameters() if p.requires_grad]
|
||||
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# restore model / discriminator / optimizer / scheduler / RNG state from a
# previous checkpoint, if one was given on the command line
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: checkpoints are saved with the key 'scheduler_state_dict', but the
    # original guard tested 'scheduler_state_disc', so scheduler state was never restored
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss.

    Computes 1 - <y_true, y_pred> / sqrt(||y_true||^2 * ||y_pred||^2 + 1e-9),
    reduced over all non-batch dimensions and averaged over the batch:
    ~0 for perfectly correlated signals, ~2 for anti-correlated ones.
    """
    reduce_dims = list(range(1, y_true.dim()))

    inner = torch.sum(y_true * y_pred, dim=reduce_dims)
    energy = torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims)
    per_item = 1 - inner / torch.sqrt(energy + 1e-9)

    return torch.mean(per_item)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain MSE normalized per item by the RMS of the prediction.

    Reduces over all non-batch dimensions, then averages over the batch.
    """
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** 0.5

    # 1e-6 guards against division by zero for an all-zero prediction
    return torch.mean(mse / (rms + 1e-6))
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 loss, optionally scale-normalized.

    Per item: mean |y_true - y_pred| / (mean |y_pred| + 1e-9) ** pow,
    averaged over the batch. pow=0 (default) gives the plain L1 mean.
    """
    reduce_dims = list(range(1, y_true.dim()))

    l1 = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return torch.mean(l1 / scale)
|
||||
|
||||
def criterion(x, y):
    """Combined regression loss: weighted sum of time-domain L1/L2, multi-resolution
    STFT, log-mel and cross-correlation losses, normalized by the total weight w_sum.

    Weights (w_l1, w_logmel, w_xcorr, w_l2, w_sum) and loss modules (stftloss,
    logmelloss) are module-level globals read from the training setup.
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
log_interval = 10
|
||||
|
||||
|
||||
m_r = 0
|
||||
m_f = 0
|
||||
s_r = 1
|
||||
s_f = 1
|
||||
|
||||
def optimizer_to(optim, device):
    """Move every tensor held in an optimizer's state to *device*, in place.

    Needed after loading optimizer state on one device and training on another;
    handles both bare tensors and the usual per-parameter state dicts.
    """
    def _move(t):
        t.data = t.data.to(device)
        if t._grad is not None:
            t._grad.data = t._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _move(entry)
        elif isinstance(entry, dict):
            for sub in entry.values():
                if isinstance(sub, torch.Tensor):
                    _move(sub)
|
||||
|
||||
optimizer_to(optimizer, device)
|
||||
optimizer_to(optimizer_disc, device)
|
||||
|
||||
retain_grads(model)
|
||||
retain_grads(disc)
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
|
||||
model.to(device)
|
||||
disc.to(device)
|
||||
model.train()
|
||||
disc.train()
|
||||
|
||||
running_disc_loss = 0
|
||||
running_adv_loss = 0
|
||||
running_feature_loss = 0
|
||||
running_reg_loss = 0
|
||||
running_disc_grad_norm = 0
|
||||
running_model_grad_norm = 0
|
||||
|
||||
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
# set gradients to zero
|
||||
optimizer.zero_grad()
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
target = batch['target'].to(device)
|
||||
disc_target = batch[adv_target].to(device)
|
||||
|
||||
# calculate model output
|
||||
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
|
||||
|
||||
# discriminator update
|
||||
scores_gen = disc(output.detach())
|
||||
scores_real = disc(disc_target.unsqueeze(1))
|
||||
|
||||
disc_loss = 0
|
||||
for score in scores_gen:
|
||||
disc_loss += (((score[-1]) ** 2)).mean()
|
||||
m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
|
||||
s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()
|
||||
|
||||
for score in scores_real:
|
||||
disc_loss += (((1 - score[-1]) ** 2)).mean()
|
||||
m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
|
||||
s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()
|
||||
|
||||
disc_loss = 0.5 * disc_loss / len(scores_gen)
|
||||
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
|
||||
|
||||
disc.zero_grad()
|
||||
disc_loss.backward()
|
||||
|
||||
running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()
|
||||
|
||||
optimizer_disc.step()
|
||||
|
||||
# generator update
|
||||
scores_gen = disc(output)
|
||||
|
||||
# calculate loss
|
||||
loss_reg = criterion(output.squeeze(1), target)
|
||||
|
||||
num_discs = len(scores_gen)
|
||||
gen_loss = 0
|
||||
for score in scores_gen:
|
||||
gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs
|
||||
|
||||
loss_feat = 0
|
||||
for k in range(num_discs):
|
||||
num_layers = len(scores_gen[k]) - 1
|
||||
f = 4 / num_discs / num_layers
|
||||
for l in range(num_layers):
|
||||
loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
|
||||
|
||||
model.zero_grad()
|
||||
|
||||
(gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
# sparsification
|
||||
if hasattr(model, 'sparsifier'):
|
||||
model.sparsifier()
|
||||
|
||||
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
|
||||
running_adv_loss += gen_loss.detach().cpu().item()
|
||||
running_disc_loss += disc_loss.detach().cpu().item()
|
||||
running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
|
||||
running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
|
||||
|
||||
# update status bar
|
||||
if i % log_interval == 0:
|
||||
tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
|
||||
disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
|
||||
feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
|
||||
reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
|
||||
model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
|
||||
disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
|
||||
wc=f"{100*winning_chance:5.2f}%")
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['disc_state_dict'] = disc.state_dict()
|
||||
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
|
||||
checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
|
||||
checkpoint['scheduler_state_dict'] = scheduler.state_dict()
|
||||
checkpoint['torch_rng_state'] = torch.get_rng_state()
|
||||
checkpoint['numpy_rng_state'] = np.random.get_state()
|
||||
checkpoint['python_rng_state'] = random.getstate()
|
||||
checkpoint['adv_loss'] = running_adv_loss/(i + 1)
|
||||
checkpoint['disc_loss'] = running_disc_loss/(i + 1)
|
||||
checkpoint['feature_loss'] = running_feature_loss/(i + 1)
|
||||
checkpoint['reg_loss'] = running_reg_loss/(i + 1)
|
||||
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
print('Done')
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import math as m
|
||||
import random
|
||||
|
||||
import yaml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
import pesq
|
||||
|
||||
from data import LPCNetVocodingDataset
|
||||
from models import model_dict
|
||||
|
||||
|
||||
from utils.lpcnet_features import load_lpcnet_features
|
||||
from utils.misc import count_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder; ask for confirmation before writing into an existing one
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires a status argument and would raise a
        # TypeError here; sys.exit gives the intended clean shutdown instead
        sys.exit(0)

# exist_ok makes this a no-op when the user confirmed reuse of an existing folder
os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
|
||||
ref = None
|
||||
# prepare inference test if wanted
|
||||
inference_test = False
|
||||
if type(args.test_features) != type(None):
|
||||
test_features = load_lpcnet_features(args.test_features)
|
||||
features = test_features['features']
|
||||
periods = test_features['periods']
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(inference_folder, exist_ok=True)
|
||||
inference_test = True
|
||||
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
lr_gen = lr * setup['training']['gen_lr_reduction']
|
||||
lambda_feat = setup['training']['lambda_feat']
|
||||
lambda_reg = setup['training']['lambda_reg']
|
||||
adv_target = setup['training'].get('adv_target', 'target')
|
||||
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = LPCNetVocodingDataset(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
|
||||
# create discriminator
|
||||
disc_name = setup['discriminator']['name']
|
||||
disc = model_dict[disc_name](
|
||||
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
|
||||
)
|
||||
|
||||
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
|
||||
|
||||
# disc optimizer
|
||||
parameters = [p for p in disc.parameters() if p.requires_grad]
|
||||
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# restore model / discriminator / optimizer / scheduler / RNG state from a
# previous checkpoint, if one was given on the command line
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: checkpoints are saved with the key 'scheduler_state_dict', but the
    # original guard tested 'scheduler_state_disc', so scheduler state was never restored
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss, averaged over the batch.

    1 - <y_true, y_pred> / sqrt(||y_true||^2 * ||y_pred||^2 + 1e-9), reduced
    over all non-batch dimensions; ~0 when correlated, ~2 when anti-correlated.
    """
    reduce_dims = list(range(1, y_true.dim()))

    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    denominator = torch.sqrt(
        torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims) + 1e-9
    )

    return torch.mean(1 - numerator / denominator)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain MSE normalized per item by the RMS of the prediction."""
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** 0.5

    # 1e-6 guards against division by zero for an all-zero prediction
    return (mse / (rms + 1e-6)).mean()
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 loss, optionally normalized by mean |y_pred| ** pow.

    pow=0 (default) reduces to the plain mean absolute error.
    """
    reduce_dims = list(range(1, y_true.dim()))

    l1 = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return (l1 / scale).mean()
|
||||
|
||||
def criterion(x, y):
    """Combined regression loss: weighted sum of time-domain L1/L2, multi-resolution
    STFT, log-mel and cross-correlation losses, normalized by the total weight w_sum.

    Weights (w_l1, w_logmel, w_xcorr, w_l2, w_sum) and loss modules (stftloss,
    logmelloss) are module-level globals read from the training setup.
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
log_interval = 10
|
||||
|
||||
|
||||
m_r = 0
|
||||
m_f = 0
|
||||
s_r = 1
|
||||
s_f = 1
|
||||
|
||||
def optimizer_to(optim, device):
    """Move every tensor held in an optimizer's state to *device*, in place.

    Needed after loading optimizer state on one device and training on another;
    handles both bare tensors and the usual per-parameter state dicts.
    """
    def _move(t):
        t.data = t.data.to(device)
        if t._grad is not None:
            t._grad.data = t._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _move(entry)
        elif isinstance(entry, dict):
            for sub in entry.values():
                if isinstance(sub, torch.Tensor):
                    _move(sub)
|
||||
|
||||
optimizer_to(optimizer, device)
|
||||
optimizer_to(optimizer_disc, device)
|
||||
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
|
||||
model.to(device)
|
||||
disc.to(device)
|
||||
model.train()
|
||||
disc.train()
|
||||
|
||||
running_disc_loss = 0
|
||||
running_adv_loss = 0
|
||||
running_feature_loss = 0
|
||||
running_reg_loss = 0
|
||||
|
||||
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
# set gradients to zero
|
||||
optimizer.zero_grad()
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
target = batch['target'].to(device)
|
||||
disc_target = batch[adv_target].to(device)
|
||||
|
||||
# calculate model output
|
||||
output = model(batch['features'], batch['periods'])
|
||||
|
||||
# discriminator update
|
||||
scores_gen = disc(output.detach())
|
||||
scores_real = disc(disc_target.unsqueeze(1))
|
||||
|
||||
disc_loss = 0
|
||||
for scale in scores_gen:
|
||||
disc_loss += ((scale[-1]) ** 2).mean()
|
||||
m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
|
||||
s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
|
||||
|
||||
for scale in scores_real:
|
||||
disc_loss += ((1 - scale[-1]) ** 2).mean()
|
||||
m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
|
||||
s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
|
||||
|
||||
disc_loss = 0.5 * disc_loss / len(scores_gen)
|
||||
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
|
||||
|
||||
disc.zero_grad()
|
||||
disc_loss.backward()
|
||||
optimizer_disc.step()
|
||||
|
||||
# generator update
|
||||
scores_gen = disc(output)
|
||||
|
||||
|
||||
# calculate loss
|
||||
loss_reg = criterion(output.squeeze(1), target)
|
||||
|
||||
num_discs = len(scores_gen)
|
||||
loss_gen = 0
|
||||
for scale in scores_gen:
|
||||
loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs
|
||||
|
||||
loss_feat = 0
|
||||
for k in range(num_discs):
|
||||
num_layers = len(scores_gen[k]) - 1
|
||||
f = 4 / num_discs / num_layers
|
||||
for l in range(num_layers):
|
||||
loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
|
||||
|
||||
model.zero_grad()
|
||||
|
||||
(loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
running_adv_loss += loss_gen.detach().cpu().item()
|
||||
running_disc_loss += disc_loss.detach().cpu().item()
|
||||
running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
|
||||
running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
|
||||
|
||||
# update status bar
|
||||
if i % log_interval == 0:
|
||||
tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
|
||||
disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
|
||||
feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
|
||||
reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
|
||||
wc=f"{100*winning_chance:5.2f}%")
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['disc_state_dict'] = disc.state_dict()
|
||||
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
|
||||
checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
|
||||
checkpoint['scheduler_state_dict'] = scheduler.state_dict()
|
||||
checkpoint['torch_rng_state'] = torch.get_rng_state()
|
||||
checkpoint['numpy_rng_state'] = np.random.get_state()
|
||||
checkpoint['python_rng_state'] = random.getstate()
|
||||
checkpoint['adv_loss'] = running_adv_loss/(i + 1)
|
||||
checkpoint['disc_loss'] = running_disc_loss/(i + 1)
|
||||
checkpoint['feature_loss'] = running_feature_loss/(i + 1)
|
||||
checkpoint['reg_loss'] = running_reg_loss/(i + 1)
|
||||
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(features, periods).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
print('Done')
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from models import model_dict
|
||||
from utils import endoscopy
|
||||
|
||||
# Command-line interface: a folder with the two pretrained checkpoints and a
# destination folder for the generated testvector files.
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and nolace_checkpoint.pth"')
parser.add_argument('output_folder', type=str, help='output folder for testvectors')
parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
||||
|
||||
|
||||
def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
    """Generate a random input/output testvector for an AdaConv module.

    Runs `adaconv` on random features and a random input signal, then dumps
    features, input and output as raw float32 files named `prefix + suffix`.
    """
    n_feat = adaconv.feature_dim
    ch_in = adaconv.in_channels
    ch_out = adaconv.out_channels
    fsize = adaconv.frame_size

    # keep this sampling order: it fixes the RNG stream for reproducibility
    feat = torch.randn((1, num_frames, n_feat))
    sig_in = torch.randn((1, ch_in, num_frames * fsize))

    sig_out = adaconv(sig_in, feat, debug=debug)

    # reorder signals to (frame, channel, sample) before dumping
    feat_np = feat[0].detach().numpy()
    in_np = sig_in[0].reshape(ch_in, num_frames, fsize).permute(1, 0, 2).detach().numpy()
    out_np = sig_out[0].reshape(ch_out, num_frames, fsize).permute(1, 0, 2).detach().numpy()

    for suffix, arr in (('_features.f32', feat_np),
                        ('_x_in.f32', in_np),
                        ('_x_out.f32', out_np)):
        arr.tofile(prefix + suffix)
def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
    """Generate a random input/output testvector for an AdaComb module.

    Besides features and signal, AdaComb also takes per-frame pitch lags,
    which are dumped as int32 (`_p_in.s32`).
    """
    n_feat = adacomb.feature_dim
    ch_in = 1
    fsize = adacomb.frame_size

    # keep this sampling order: it fixes the RNG stream for reproducibility
    feat = torch.randn((1, num_frames, n_feat))
    sig_in = torch.randn((1, ch_in, num_frames * fsize))
    lag_in = torch.randint(adacomb.kernel_size, 250, (1, num_frames))

    sig_out = adacomb(sig_in, feat, lag_in, debug=debug)

    feat_np = feat[0].detach().numpy()
    in_np = sig_in[0].permute(1, 0).detach().numpy()
    lag_np = lag_in[0].detach().numpy().astype(np.int32)
    out_np = sig_out[0].permute(1, 0).detach().numpy()

    feat_np.tofile(prefix + '_features.f32')
    in_np.tofile(prefix + '_x_in.f32')
    lag_np.tofile(prefix + '_p_in.s32')
    out_np.tofile(prefix + '_x_out.f32')
def create_adashape_testvector(prefix, adashape, num_frames):
    """Generate a random input/output testvector for an AdaShape module."""
    n_feat = adashape.feature_dim
    fsize = adashape.frame_size

    # keep this sampling order: it fixes the RNG stream for reproducibility
    feat = torch.randn((1, num_frames, n_feat))
    sig_in = torch.randn((1, 1, num_frames * fsize))

    sig_out = adashape(sig_in, feat)

    for suffix, arr in (
        ('_features.f32', feat[0].detach().numpy()),
        ('_x_in.f32', sig_in.flatten().detach().numpy()),
        ('_x_out.f32', sig_out.flatten().detach().numpy()),
    ):
        arr.tofile(prefix + suffix)
def create_feature_net_testvector(prefix, model, num_frames):
    """Generate a testvector for the model's feature net.

    Random acoustic features, pitch periods and numbits values are embedded
    and concatenated exactly as the model does internally, then pushed
    through `model.feature_net`; all intermediates are dumped as raw files.
    """
    num_features = model.num_features
    # assumes 4 feature subframes per frame — TODO confirm against model
    num_subframes = 4 * num_frames

    input_features = torch.randn((1, num_subframes, num_features))
    periods = torch.randint(32, 300, (1, num_subframes))
    # numbits drawn uniformly from the model's configured range
    numbits = model.numbits_range[0] + torch.rand((1, num_frames, 2)) * (model.numbits_range[1] - model.numbits_range[0])


    # embed periods per subframe; numbits are per frame, so repeat 4x
    pembed = model.pitch_embedding(periods)
    nembed = torch.repeat_interleave(model.numbits_embedding(numbits).flatten(2), 4, dim=1)
    full_features = torch.cat((input_features, pembed, nembed), dim=-1)

    cf = model.feature_net(full_features)

    input_features.float().numpy().tofile(prefix + "_in_features.f32")
    periods.numpy().astype(np.int32).tofile(prefix + "_periods.s32")
    numbits.float().numpy().tofile(prefix + "_numbits.f32")
    full_features.detach().numpy().tofile(prefix + "_full_features.f32")
    cf.detach().numpy().tofile(prefix + "_out_features.f32")
||||
if __name__ == "__main__":
    # Load both pretrained models (LACE and NoLACE) from the checkpoint
    # folder and dump testvectors for selected submodules.
    args = parser.parse_args()

    os.makedirs(args.output_folder, exist_ok=True)

    lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
    nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')

    # instantiate models from the kwargs stored inside each checkpoint
    lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
    nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])

    lace.load_state_dict(lace_checkpoint['state_dict'])
    nolace.load_state_dict(nolace_checkpoint['state_dict'])

    # optional debug instrumentation writing extra dumps to the output folder
    if args.debug:
        endoscopy.init(args.output_folder)

    # lace af1, 1 input channel, 1 output channel
    create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, 5, debug=args.debug)

    # nolace af1, 1 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, 5, debug=args.debug)

    # nolace af4, 2 input channel, 1 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, 5, debug=args.debug)

    # nolace af2, 2 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, 5, debug=args.debug)

    # lace cf1
    create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, 5, debug=args.debug)

    # nolace tdshape1
    create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, 5)

    # lace feature net
    create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, 5)

    if args.debug:
        endoscopy.close()
||||
@@ -0,0 +1,2 @@
|
||||
from .silk_enhancement_set import SilkEnhancementSet
|
||||
from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
|
||||
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" Dataset for LPCNet training """
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# mu-law companding constants: map the 16-bit linear range onto 255
# quantization levels (scale) and back (scale_1).
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
||||
def ulaw2lin(u):
    """Expand 8-bit mu-law code values (0..255, 128 = zero) to linear amplitude."""
    centered = u - 128
    sign = np.sign(centered)
    magnitude = np.abs(centered)
    # inverse of the log companding curve, rescaled to the 16-bit range
    return sign * scale_1 * (np.exp(magnitude / 128. * np.log(256)) - 1)
|
||||
|
||||
def lin2ulaw(x):
    """Compress linear amplitude to 8-bit mu-law code values (0..255, 128 = zero)."""
    sign = np.sign(x)
    magnitude = np.abs(x)
    # log companding curve scaled so full-scale maps to +/-128 around the midpoint
    code = sign * (128 * np.log(1 + scale * magnitude) / np.log(256))
    return np.clip(128 + np.round(code), 0, 255)
|
||||
def run_lpc(signal, lpcs, frame_length=160):
    """Run frame-wise LPC prediction over *signal*.

    Args:
        signal: 1-d array of length num_frames * frame_length + lpc_order
            (the leading lpc_order samples serve as filter history).
        lpcs: (num_frames, lpc_order) array of LPC coefficients, one row
            applied per frame.
        frame_length: samples per frame.

    Returns:
        (prediction, error): the per-sample LPC prediction and the residual
        signal[lpc_order:] - prediction, both of length
        num_frames * frame_length.
    """
    num_frames, lpc_order = lpcs.shape

    chunks = []
    for frame in range(num_frames):
        start = frame * frame_length
        # 'valid' convolution over frame + (lpc_order - 1) history samples
        # yields exactly frame_length prediction values
        stop = start + frame_length + lpc_order - 1
        chunks.append(-np.convolve(signal[start:stop], lpcs[frame], mode='valid'))
    prediction = np.concatenate(chunks)

    error = signal[lpc_order:] - prediction

    return prediction, error
|
||||
class LPCNetVocodingDataset(Dataset):
    """Dataset of memmapped LPCNet vocoder features and signals.

    Reads an `info.yml` describing the on-disk layout (feature/signal files,
    dtypes, frame layouts), memmaps both files and serves samples of
    `frames_per_sample` frames. Two on-disk versions are supported and
    dispatched to `getitem_v1` / `getitem_v2`.
    """
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info describing files, dtypes and frame layouts
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the __getitem__ implementation
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # feature windowing parameters; frame_offset leaves room for history
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        # bandwidth-expansion factor applied to LPC coefficients
        self.lpc_gamma = lpc_gamma

        # memmap the flat feature file and reshape to (frames, frame_length)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset (margins reserved for
        # history/lookahead at both ends)
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # memmap the signal file; one row of signal channels per sample
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length


    def __getitem__(self, index):
        # dispatch to the version-specific implementation chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Assemble one version-2 training sample."""
        sample = dict()

        # extract features (with history before and lookahead after the sample)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert normalized period features back to integer pitch lags
        # (inverse of the encoder-side normalization — TODO confirm constants)
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC bandwidth expansion: coefficient i scaled by gamma^(i+1)
            lpc_order = lpc_stop - lpc_start
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # drop the extra leading frame used for excitation history
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
            # calculate error between real signal and noisy prediction

            sample['error'] = sample['signal'] - sample['prediction']


        # concatenate all non-period features into one tensor
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # target normalized from 16-bit PCM to [-1, 1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        """Assemble one version-1 training sample (integer signals, no LPC)."""
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # NOTE(review): the loop variable shadows the 'index' parameter; the
        # parameter is not used afterwards, but this is fragile — rename it.
        for signal_name, index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # NOTE(review): self.input_signals is never assigned in __init__, so
        # this line raises AttributeError for version-1 datasets — verify.
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        # number of complete samples derived in __init__
        return self.dataset_length
||||
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
from utils.silk_features import silk_feature_factory
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
|
||||
|
||||
class SilkEnhancementSet(Dataset):
    """Dataset pairing SILK-coded speech with precomputed codec features.

    Loads flat feature dumps (LPCs, LTPs, pitch periods, gains, bit-rate
    statistics, offsets, LPCNet features) plus the coded signal from *path*
    and serves fixed-length samples of `frames_per_sample` 80-sample frames.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=9,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_offset=False,
                 add_double_lag_acorr=False,
                 skip=0
                 ):

        assert frames_per_sample % 4 == 0

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr
        # sample offset between signal files and the frame grid. New,
        # backward-compatible parameter (default 0 preserves old indexing):
        # previously self.skip was read in __getitem__ but never assigned,
        # raising AttributeError on first access.
        self.skip = skip

        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
        self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
        # fix: numpy has no np.from_file; the correct API is np.fromfile
        self.lpcnet_features = np.fromfile(os.path.join(path, 'features_lpcnet.f32'), dtype=np.float32).reshape(-1, 36)

        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_offset,
                                                    add_double_lag_acorr)

        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some leading frames so every sample has enough signal history
        self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)

        # fix: this class never loads a clean signal (self.clean_signal was
        # undefined); derive the frame count from the coded signal it uses
        num_frames = self.coded_signal.shape[0] // 80 - self.skip_frames

        self.len = num_frames // frames_per_sample

    def __len__(self):
        """Number of fixed-length samples in the dataset."""
        return self.len

    def __getitem__(self, index):
        """Return silk features, periods, numbits and LPCNet features for one sample."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # normalize 16-bit PCM to [-1, 1)
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop],
            self.offsets[frame_start : frame_stop]
        )

        # LPCNet features run at half the frame rate; keep the first 20 dims
        lpcnet_features = self.lpcnet_features[frame_start // 2 : frame_stop // 2, :20]

        # bit-rate statistics run at quarter frame rate; upsample by repetition
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'silk_features' : features,
            'periods' : periods.astype(np.int64),
            'numbits' : numbits.astype(np.float32),
            'lpcnet_features' : lpcnet_features
        }
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
from utils.silk_features import silk_feature_factory
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
|
||||
|
||||
class SilkEnhancementSet(Dataset):
    """Dataset pairing SILK-coded speech with clean target signals.

    Serves samples of `frames_per_sample` 80-sample frames: codec features,
    pitch periods, bit-rate statistics, the coded (input) signal and the
    clean / high-pass-filtered clean targets, optionally pre-emphasized.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=256,
                 preemph=0.85,
                 skip=91,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_double_lag_acorr=False,
                 ):

        assert frames_per_sample % 4 == 0

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        # pre-emphasis coefficient applied to all signals in __getitem__
        self.preemph = preemph
        # sample offset between the signal files and the frame grid
        self.skip = skip
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr

        # flat feature dumps: 16 LPCs and 5 LTP taps per frame
        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)

        # 16-bit PCM signals: clean (high-passed and plain) and coded
        self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
        self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_double_lag_acorr)

        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some frames to have enough signal history
        self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)

        num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames

        self.len = num_frames // frames_per_sample

    def __len__(self):
        # number of complete samples derived in __init__
        return self.len

    def __getitem__(self, index):
        """Return features, periods, numbits, input signal and targets for one sample."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # normalize 16-bit PCM to [-1, 1)
        clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15
        clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop]
        )

        # first-order pre-emphasis applied in place to the float copies
        if self.preemph > 0:
            clean_signal[1:] -= self.preemph * clean_signal[: -1]
            clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1]
            coded_signal[1:] -= self.preemph * coded_signal[: -1]

        # bit-rate statistics run at quarter frame rate; upsample by repetition
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'features' : features,
            'periods' : periods.astype(np.int64),
            'target_orig' : clean_signal.astype(np.float32),
            'target' : clean_signal_hp.astype(np.float32),
            'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
            'numbits' : numbits.astype(np.float32)
        }
||||
103
managed_components/78__esp-opus/dnn/torch/osce/engine/engine.py
Normal file
103
managed_components/78__esp-opus/dnn/torch/osce/engine/engine.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the average loss over the dataloader.

    Steps optimizer and scheduler per batch, applies the model's sparsifier
    when present, and updates a tqdm status bar every `log_interval` batches.
    """
    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for step, batch in enumerate(tepoch):
            optimizer.zero_grad()

            # move all batch tensors onto the target device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # forward pass; signals are permuted from (batch, time, channels)
            # to (batch, channels, time) as the model expects
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # multi-stage models return a list of outputs: average the losses
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for stage_out in output:
                    loss = loss + criterion(target, stage_out.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            loss.backward()
            optimizer.step()
            scheduler.step()

            # apply weight sparsification if the model supports it
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            running_loss += float(loss.cpu())

            if step % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(step + 1):8.7f}",
                                   current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    return running_loss / len(dataloader)
||||
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Compute the average validation loss over the dataloader (no gradients)."""
    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for step, batch in enumerate(tepoch):
                # move all batch tensors onto the target device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # forward pass; signals permuted to (batch, channels, time)
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                loss = criterion(target, output.squeeze(1))

                running_loss += float(loss.cpu())

                if step % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(step + 1):8.7f}",
                                       current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    return running_loss / len(dataloader)
||||
@@ -0,0 +1,101 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Train *model* for one epoch over *dataloader* and return the average loss.

    Optimizer and learning-rate scheduler are both stepped once per batch.
    If the model returns a list of outputs, the criterion is applied to each
    element and averaged. Progress is reported on stdout via tqdm every
    *log_interval* batches.
    """

    model.to(device)
    model.train()

    # running_loss accumulates over the epoch; previous_running_loss remembers
    # the value at the last status-bar update so a recent-loss delta can be shown.
    running_loss = 0
    previous_running_loss = 0


    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()


            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # calculate loss; a list-valued output is scored per element
            # and averaged over the list
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate (scheduler is stepped per batch, not per epoch)
            scheduler.step()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
|
||||
|
||||
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Run one gradient-free evaluation pass over *dataloader* and return the average loss.

    The model is moved to *device* and put into eval mode; running and recent
    loss are written to stdout via tqdm every *log_interval* batches.
    """

    model.to(device)
    model.eval()

    # running_loss accumulates over all batches; previous_running_loss remembers
    # the value at the last status-bar update so a recent-loss delta can be shown.
    running_loss = 0
    previous_running_loss = 0


    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):


                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
|
||||
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import hashlib
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
|
||||
|
||||
import torch
|
||||
import wexchange.torch
|
||||
from wexchange.torch import dump_torch_weights
|
||||
from models import model_dict
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.misc import remove_all_weight_norm
|
||||
from wexchange.torch import dump_torch_weights
|
||||
|
||||
|
||||
|
||||
# command-line interface for the export script
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')

# default value for every 'sparse'/'input_sparse'/'recurrent_sparse' kwarg below
sparse_default=False

# Per-model export schedules: ordered (submodule path, dump_torch_weights kwargs)
# pairs consumed by osce_scheduled_dump. NOTE(review): scale=None /
# recurrent_scale=None presumably lets the dumper choose the quantization
# scale -- confirm against wexchange.torch.dump_torch_weights.
schedules = {
    'nolace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None)),
        ('tdshape1', dict(quantize=True, scale=None)),
        ('tdshape2', dict(quantize=True, scale=None)),
        ('tdshape3', dict(quantize=True, scale=None)),
        ('af2', dict(quantize=True, scale=None)),
        ('af3', dict(quantize=True, scale=None)),
        ('af4', dict(quantize=True, scale=None)),
        ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
    ],
    'lace' : [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None))
    ]
}
|
||||
|
||||
|
||||
# auxiliary functions
|
||||
def sha1(filename):
    """Return the SHA-1 hex digest of the file at *filename*, read in 64 KiB chunks."""
    chunk_size = 65536
    digest = hashlib.sha1()

    with open(filename, 'rb') as fh:
        while chunk := fh.read(chunk_size):
            digest.update(chunk)

    return digest.hexdigest()
|
||||
|
||||
def osce_dump_generic(writer, name, module):
    """Recursively dump *module*'s weights to *writer*.

    Leaf layers of known types (torch Linear / Conv1d / ConvTranspose1d /
    Embedding / GRU and the project layers LimitedAdaptiveConv1d,
    LimitedAdaptiveComb1d, TDShaper) are dumped directly under *name*;
    any other module is descended into, appending child names as
    '<name>_<child>' with 'feature_net' abbreviated to 'fnet'.
    """
    if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv1d) \
        or isinstance(module, torch.nn.ConvTranspose1d) or isinstance(module, torch.nn.Embedding) \
        or isinstance(module, LimitedAdaptiveConv1d) or isinstance(module, LimitedAdaptiveComb1d) \
        or isinstance(module, TDShaper) or isinstance(module, torch.nn.GRU):
        dump_torch_weights(writer, module, name=name, verbose=True)
    else:
        # not a known leaf layer: recurse into children
        for child_name, child in module.named_children():
            osce_dump_generic(writer, (name + "_" + child_name).replace("feature_net", "fnet"), child)
|
||||
|
||||
|
||||
def export_name(name):
    """Map a module path like 'feature_net.conv1' to a C-friendly identifier ('fnet_conv1')."""
    return name.replace('.', '_').replace('feature_net', 'fnet')
|
||||
|
||||
def osce_scheduled_dump(writer, prefix, model, schedule):
    """Dump each submodule listed in *schedule* with its per-layer kwargs.

    *schedule* is an iterable of (submodule path, dump kwargs) pairs; every
    exported name is '<prefix>_<exported path>'.
    """
    full_prefix = prefix if prefix.endswith('_') else prefix + '_'

    for submodule_path, dump_kwargs in schedule:
        submodule = model.get_submodule(submodule_path)
        dump_torch_weights(writer, submodule, full_prefix + export_name(submodule_path), **dump_kwargs, verbose=True)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # dump message: records the source checkpoint and its hash in the generated files
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # create model and load weights (constructor args/kwargs come from the checkpoint itself)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    remove_all_weight_norm(model, verbose=True)

    # CWriter
    # NOTE(review): only wexchange.torch is imported above; accessing
    # wexchange.c_export relies on it being imported as a side effect -- confirm.
    model_name = checkpoint['setup']['model']['name']
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)

    # Add custom includes and global parameters
    cwriter.header.write(f'''
#define {model_name.upper()}_PREEMPH {model.preemph}f
#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
#define {model_name.upper()}_OVERLAP_SIZE 40
#define {model_name.upper()}_NUM_FEATURES {model.num_features}
#define {model_name.upper()}_PITCH_MAX {model.pitch_max}
#define {model_name.upper()}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
#define {model_name.upper()}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
#define {model_name.upper()}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
#define {model_name.upper()}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
#define {model_name.upper()}_COND_DIM {model.cond_dim}
#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')

    # one scale define per numbits-embedding scale factor
    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")

    # dump layers: quantized export follows the per-model schedule,
    # otherwise fall back to the generic recursive dump
    if model_name in schedules and args.quantize:
        osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
    else:
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()
|
||||
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
    """Create a window tensor by name.

    Args:
        win_name (str): one of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window', 'kaiser_window'.
        win_length (int): window length in samples.
        *args, **kwargs: forwarded to the underlying torch window factory.

    Returns:
        Tensor: the requested window.

    Raises:
        ValueError: if *win_name* is not a supported window type.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window' : torch.hamming_window,
        'hann_window' : torch.hann_window,
        'kaiser_window' : torch.kaiser_window
    }

    # fix: use the idiomatic 'not in' and name the offending input instead of
    # raising a message-less ValueError()
    if win_name not in window_dict:
        raise ValueError(f"unknown window function: {win_name!r}")

    return window_dict[win_name](win_length, *args, **kwargs)
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type (a get_window name).
    Returns:
        Tensor: Magnitude spectrogram, clamped below at 1e-7.
    """

    # NOTE(review): the window tensor is rebuilt and moved to x.device on every
    # call; consider caching per (window, win_length, device) if this is hot.
    win = get_window(window, win_length).to(x.device)
    x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)


    # clamp away (near-)zeros so downstream log/divide operations stay finite
    return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
def spectral_convergence_loss(Y_true, Y_pred):
    """Frobenius-norm spectral convergence between two spectrograms, averaged over the batch.

    Note: the denominator uses the (complex) norm of Y_pred, matching the
    original implementation.
    """
    reduce_dims = list(range(1, Y_pred.dim()))
    numerator = torch.norm(Y_true.abs() - Y_pred.abs(), p="fro", dim=reduce_dims)
    denominator = torch.norm(Y_pred, p="fro", dim=reduce_dims) + 1e-6
    return (numerator / denominator).mean()
|
||||
|
||||
|
||||
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference of the log magnitudes of two spectrograms.

    A 1e-15 floor keeps the logs finite for zero-magnitude bins.
    """
    log_true = torch.log(Y_true.abs() + 1e-15)
    log_pred = torch.log(Y_pred.abs() + 1e-15)
    return (log_true - log_pred).abs().mean()
|
||||
|
||||
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the normalized cross-correlation of the two magnitude spectra.

    Zero when the magnitudes are identical up to scale; the 1e-9 term guards
    the square root against zero energy.
    """
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, mag_pred.dim()))
    inner = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    energy = torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims)
    xcorr = inner / torch.sqrt(energy + 1e-9)
    return 1 - xcorr.mean()
|
||||
|
||||
|
||||
|
||||
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel-spectrogram loss.

    Computes mel spectrograms of target and prediction at every FFT size in
    *fft_sizes* and averages the mean absolute log-magnitude difference over
    the resolutions. FFT sizes below 128 use half as many mel bins.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        # NOTE(review): attributes are assigned before super().__init__();
        # this is fine for plain values but would fail for Parameters/Modules
        # -- consider calling super() first.
        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        super().__init__()

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            if fft_size < 128:
                # fewer mel bins at the coarsest frequency resolutions
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register the transforms as submodules (named mel_spec_1, mel_spec_2, ...)
        # so their buffers are tracked; self.mel_specs itself is a plain list
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the loss averaged over all configured resolutions."""

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
|
||||
|
||||
def create_weight_matrix(num_bins, bins_per_band=10):
    """Build a (num_bins, num_bins) band-averaging matrix.

    Each row averages roughly *bins_per_band* neighbouring bins; near the
    edges, weight that would fall outside the valid range is folded back in
    so every row still sums to one.
    """
    half_below = bins_per_band // 2
    half_above = bins_per_band - half_below

    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    for row in range(num_bins):
        lo = max(row - half_below, 0)
        hi = min(row + half_above, num_bins)
        weights[row, lo: hi] += 1

        # fold the clipped portion back in at the lower edge...
        if row < half_below:
            weights[row, :half_below - row] += 1

        # ...and at the upper edge
        if row > num_bins - half_above:
            weights[row, num_bins - half_above - row:] += 1

    return weights / bins_per_band
|
||||
|
||||
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Spectral-convergence-style loss weighted by per-band tonality.

    *w* is a band-averaging matrix (see create_weight_matrix). Bands with a
    spectral flatness measure (SFM = geometric mean / arithmetic mean of the
    magnitude) near 1, i.e. noise-like bands, are down-weighted; tonal bands
    get weight up to 1.
    """

    # calculate sfm based weights: exp(mean log|Y|) / mean |Y| per band
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)

    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)

    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)

    # weight = sqrt(1 - SFM), clipped at zero via relu
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )

    return loss
|
||||
|
||||
def gen_filterbank(N, Fs=16000):
    """Build an (N, N+1) ERB-based spectral smoothing filterbank.

    Each output row is a normalized response centred on one frequency bin,
    with bandwidth following the ERB scale; rows sum to one.
    """
    freq_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    freq_out = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb = 24.7 + .108 * freq_in
    dist = np.abs(freq_in - freq_out) / erb

    # parabolic rolloff inside half an ERB, linear (in dB) outside
    inside = (dist < .5).astype('float32')
    gain_db = -12 * inside * dist ** 2 + (1 - inside) * (3 - 12 * dist)
    response = 10. ** (gain_db / 10.)

    # normalize each row to unit sum
    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
|
||||
|
||||
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-magnitude difference after smoothing both spectra with *filterbank*."""
    smoothed_true = torch.matmul(filterbank, Y_true.abs())
    smoothed_pred = torch.matmul(filterbank, Y_pred.abs())

    # 1e-9 floor keeps the logs finite for empty bands
    diff = torch.log(smoothed_true + 1e-9) - torch.log(smoothed_pred + 1e-9)

    return diff.abs().mean()
|
||||
|
||||
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss.

    Combines up to five spectral losses -- log magnitude, spectral
    convergence, SFM-weighted spectral convergence, smoothed log magnitude,
    and spectral cross-correlation -- each scaled by its constructor weight,
    computed at every FFT size in *fft_sizes* and averaged over resolutions.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss; band width
        # grows with frequency resolution, is capped at 11 and made even
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss (frozen, one per FFT size)
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )


    def __call__(self, y_true, y_pred):
        # NOTE(review): overriding __call__ (instead of forward) bypasses
        # nn.Module's hook machinery -- confirm this is intentional.

        # per-component accumulators, summed over FFT sizes
        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            # each component is only computed when its weight is non-zero
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)


        # weighted sum, averaged over the number of resolutions
        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
|
||||
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import scipy.signal
|
||||
|
||||
|
||||
from utils.layers.fir import FIR
|
||||
|
||||
class TDLowpass(torch.nn.Module):
    """Time-domain lowpass error loss.

    Lowpass-filters the difference between target and prediction with a
    scipy-designed FIR filter and returns the mean absolute value of the
    filtered difference raised to *power*.
    """
    def __init__(self, numtaps, cutoff, power=2):
        super().__init__()

        # NOTE(review): self.weight is a plain tensor, not a registered buffer,
        # so it will not follow the module across .to(device) -- confirm callers
        # only use this on CPU or move it explicitly.
        self.b = scipy.signal.firwin(numtaps, cutoff)
        self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
        self.power = power

    def forward(self, y_true, y_pred):
        """Return (loss, lowpassed difference) for 3-D (batch, channel, time) inputs."""

        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

        diff = y_true - y_pred
        diff_lp = torch.nn.functional.conv1d(diff, self.weight)

        loss = torch.mean(torch.abs(diff_lp ** self.power))

        return loss, diff_lp

    def get_freqz(self):
        # frequency response of the designed filter (for inspection/plotting)
        freq, response = scipy.signal.freqz(self.b)

        return freq, response
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user