add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from . import quantization
from . import sparsification

View File

@@ -0,0 +1 @@
from .softquant import soft_quant, remove_soft_quant

View File

@@ -0,0 +1,113 @@
import torch
@torch.no_grad()
def compute_optimal_scale(weight):
    """Return a per-output-row quantization scale for an int8-style scheme.

    The scale per row is the larger of two constraints: the row's max
    absolute value divided by 127, and the row's max absolute sum of
    adjacent column pairs divided by 129.

    Args:
        weight (torch.Tensor): 2-d tensor of shape (n_out, n_in);
            n_in must be divisible by 4.

    Returns:
        torch.Tensor: 1-d tensor of length n_out with one scale per row.
    """
    # note: the @torch.no_grad() decorator already disables autograd, so no
    # inner `with torch.no_grad():` is needed
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # pad the row count up to the next multiple of 8 with zero rows
        # (the previous `n_out - n_out % 8` did not yield a multiple of 8;
        # padded rows are all-zero, so the returned slice is unaffected)
        pad = (-n_out) % 8
        weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)

    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129
    scale = torch.maximum(scale_max, scale_sum)

    # drop scales that belong to padding rows
    return scale[:n_out]
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draw uniform noise scaled like the per-row optimal quantization scale.

    Entries that are exactly zero in the weight (pruned by sparsification)
    receive zero noise. The noise tensor has the same shape as `weight`.
    """
    if isinstance(module, torch.nn.Conv1d):
        flat = weight.permute(0, 2, 1).flatten(1)
        noise = torch.rand_like(flat) - 0.5
        noise[flat == 0] = 0  # ignore zero entries from sparsification
        noise *= compute_optimal_scale(flat).unsqueeze(-1)
        noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        in_ch, out_ch, kernel = weight.shape
        flat = weight.permute(2, 1, 0).reshape(kernel * out_ch, in_ch)
        noise = torch.rand_like(flat) - 0.5
        noise[flat == 0] = 0  # ignore zero entries from sparsification
        noise *= compute_optimal_scale(flat).unsqueeze(-1)
        noise = noise.reshape(kernel, out_ch, in_ch).permute(2, 1, 0)
    elif weight.dim() == 2:
        noise = torch.rand_like(weight) - 0.5
        noise[weight == 0] = 0  # ignore zero entries from sparsification
        noise *= compute_optimal_scale(weight).unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')

    return noise
class SoftQuant:
    """Forward-hook pair that perturbs selected parameters with quantization
    noise while a training-mode forward pass runs, and restores them after.

    Registered as a forward-pre-hook (adds noise) plus a forward-hook
    (subtracts the same noise), so evaluation passes and the stored
    parameter values are unaffected.
    """

    # parameter attribute names to perturb (previously mis-declared `name: str`)
    names: list

    def __init__(self, names, scale) -> None:
        self.names = names
        # dict name -> noise tensor, populated only between the two hooks
        self.quantization_noise = None
        # fixed noise scale; None selects the per-row q_scaled_noise scale
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Add (before=True) or remove (before=False) noise on module params."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    def apply(module, names=['weight'], scale=None):
        """Attach a SoftQuant hook pair to `module`.

        Returns the SoftQuant instance; raises ValueError when `module`
        lacks one of the named attributes.
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                # previously raised with an empty message
                raise ValueError(f"module has no attribute named {name}")

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        # tag the hooks so remove_soft_quant can find them later
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
def soft_quant(module, names=['weight'], scale=None):
    """Attach soft-quantization noise hooks to `module` and return it.

    Convenience wrapper around SoftQuant.apply; the returned module can be
    used directly in construction chains. (The previously captured return
    value of apply() was unused and has been dropped.)
    """
    SoftQuant.apply(module, names, scale)
    return module
def remove_soft_quant(module, names=['weight']):
    """Remove SoftQuant hooks matching `names` from `module` and return it.

    Iterates over snapshots of the hook dicts because deleting entries from
    an OrderedDict while iterating it raises RuntimeError.
    """
    for k, hook in list(module._forward_pre_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_hooks[k]

    return module

View File

@@ -0,0 +1,2 @@
from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
from .meta_critic import MetaCritic

View File

@@ -0,0 +1,85 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
class MetaCritic():
    def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
        """ Class for assessing relevance of discriminator scores

        Args:
            normalize (bool, optional): if True, scores are normalized with running
                per-discriminator mean/std statistics. Defaults to False.
            gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
            beta (float, optional): minimum confidence related threshold. Defaults to 0.0.
            joint_stats (bool, optional): if True, a single shared statistic is kept
                for all discriminators. Defaults to False.
        """
        self.normalize = normalize
        self.gamma = gamma
        self.beta = beta
        self.joint_stats = joint_stats
        # per-discriminator running {'std': ..., 'mean': ...}
        self.disc_stats = dict()

    def __call__(self, disc_id, real_scores, generated_scores):
        """ calculates relevance from normalized scores

        Args:
            disc_id (any valid key): id for tracking discriminator statistics
            real_scores (torch.tensor): scores for real data
            generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device

        Returns:
            torch.tensor: output-domain relevance
        """
        if self.normalize:
            real_std = torch.std(real_scores.detach()).cpu().item()
            gen_std = torch.std(generated_scores.detach()).cpu().item()
            std = (real_std**2 + gen_std**2) ** .5
            mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()

            key = 0 if self.joint_stats else disc_id

            if key in self.disc_stats:
                # exponential moving average of std and mean
                self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
                self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
            else:
                # small epsilon keeps the tracked std strictly positive
                self.disc_stats[key] = {
                    'std': std + 1e-5,
                    'mean': mean
                }
            std = self.disc_stats[key]['std']
            mean = self.disc_stats[key]['mean']
        else:
            mean, std = 0, 1

        # removed dead `if False: print(...)` debug statement
        relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)

        return relevance

View File

@@ -0,0 +1,449 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
import torch.nn.functional as F
def view_one_hot(index, length):
    """Return a `length`-element shape vector of ones with -1 at `index`.

    Used with Tensor.view to broadcast a 1-d tensor along one dimension of
    an N-dimensional kernel.
    """
    return [-1 if i == index else 1 for i in range(length)]
def create_smoothing_kernel(widths, gamma=1.5):
    """ creates a truncated gaussian smoothing kernel for the given widths

    Parameters:
    -----------
    widths: list[Int] or torch.LongTensor
        specifies the shape of the smoothing kernel, entries must be > 0.

    gamma: float, optional
        decay factor for gaussian relative to kernel size

    Returns:
    --------
    kernel: torch.FloatTensor
        kernel of shape `widths`, normalized to sum to one
    """
    widths = torch.LongTensor(widths)
    num_dims = len(widths)
    assert(widths.min() > 0)
    centers = widths.float() / 2 - 0.5
    sigmas = gamma * (centers + 1)

    # separable squared distances, broadcast into the full kernel shape
    # (a redundant `vals = []` initializer was removed)
    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
    vals = sum([vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims)])

    kernel = torch.exp(- vals)
    kernel = kernel / kernel.sum()

    return kernel
def create_partition_kernel(widths, strides):
    """ creates a partition kernel for mapping a convolutional network output back to the input domain

    Given a fully convolutional network with receptive field of shape widths and the given strides, this
    function constructs an interpolation kernel whose translations by multiples of the given strides form
    a partition of one on the input domain.

    Parameter:
    ----------
    widths: list[Int] or torch.LongTensor
        shape of receptive field

    strides: list[Int] or torch.LongTensor
        total strides of convolutional network

    Returns:
    kernel: torch.FloatTensor
    """
    num_dims = len(widths)
    assert num_dims == len(strides) and num_dims in {1, 2, 3}
    # pick the convolution matching the spatial dimensionality
    convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}

    widths = torch.LongTensor(widths)
    strides = torch.LongTensor(strides)

    # box prototype whose translations by `strides` tile the input domain
    proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())

    # create interpolation kernel eta
    eta_widths = widths - strides + 1
    if eta_widths.min() <= 0:
        print("[create_partition_kernel] warning: receptive field does not cover input domain")
        # clamp to width 1 so the smoothing kernel remains well-defined
        eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))

    eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())

    padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
    padded_proto_kernel = F.pad(proto_kernel, padding)
    padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
    # smooth the box prototype with eta; the result still partitions unity
    kernel = convs[num_dims](padded_proto_kernel, eta)

    return kernel
def receptive_field(conv_model, input_shape, output_position):
    """ estimates boundaries of receptive field connected to output_position via autograd

    Parameters:
    -----------
    conv_model: nn.Module or autograd function
        function or model implementing fully convolutional model

    input_shape: List[Int]
        input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]

    output_position: List[Int]
        output position for which the receptive field is determined; raises
        ValueError if out of bounds for the given input_shape.

    Returns:
    --------
    low: List[Int]
        start indices of receptive field
    high: List[Int]
        stop indices of receptive field
    """
    probe = torch.randn((1,) + tuple(input_shape), requires_grad=True)
    out = conv_model(probe)

    # collapse channels and remove batch dimension
    out = torch.sum(out, 1)[0]

    # build a one-hot mask selecting the requested output position
    mask = torch.zeros_like(out)
    index = [torch.tensor(p) for p in output_position]
    try:
        mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
    except IndexError:
        raise ValueError('output_position out of bounds')

    (mask * out).sum().backward()

    # nonzero input gradients delimit the receptive field
    grad = torch.sum(probe.grad, dim=1)[0]
    support = torch.nonzero(grad, as_tuple=True)

    low = [dim_idx.min().item() for dim_idx in support]
    high = [dim_idx.max().item() for dim_idx in support]

    return low, high
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
    """ attempts to estimate receptive field size, strides and left paddings for given model

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    width: Int
        width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    max_stride: Int, optional
        assumed maximal stride of the model for any dimension, when set too low the function may fail for
        any value of width

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError, KeyError
    """
    input_shape = [num_channels] + num_dims * [width]
    output_position1 = num_dims * [width // (2 * max_stride)]
    output_position2 = num_dims * [width // (2 * max_stride) + 1]

    # receptive fields of two adjacent output positions
    low1, high1 = receptive_field(model, input_shape, output_position1)
    low2, high2 = receptive_field(model, input_shape, output_position2)
    widths1 = [h - l + 1 for l, h in zip(low1, high1)]
    widths2 = [h - l + 1 for l, h in zip(low2, high2)]

    # both fields must have equal size and be shifted by a nonzero stride;
    # error message fixed (was "[estimate_strides]: widths to small ...")
    if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
        raise ValueError("[estimate_conv_parameters]: width too small to determine strides")

    receptive_field_size = widths1
    strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
    left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]

    return receptive_field_size, strides, left_paddings
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
    """ determines size of receptive field, strides and padding probabilistically

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    max_width: Int
        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    width_hint: Int or None, optional
        if given, the first trial uses twice this width

    stride_hint: Int or None, optional
        if given, the first trial uses this stride

    verbose: bool, optional
        if true, the function prints parameters for individual trials

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError
    """
    max_stride = max_width // 2
    stride = max_stride // 100
    width = max_width // 100
    if width_hint is not None: width = 2 * width_hint
    if stride_hint is not None: stride = stride_hint

    did_it = False
    while width < max_width and stride < max_stride:
        try:
            if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
            receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
            did_it = True
        except (ValueError, KeyError):
            # estimation failed at this trial size; grow width/stride and retry
            # (previously a bare `except:` which also swallowed KeyboardInterrupt)
            pass

        if did_it: break

        width *= 2
        if width >= max_width and stride < max_stride:
            stride *= 2
            width = 2 * stride

    if not did_it:
        raise ValueError(f'could not determine conv parameter with given max_width={max_width}')

    return receptive_field_size, strides, left_paddings
class GradWeight(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by a
    fixed weight tensor in the backward pass.

    (The previous empty __init__ override was removed: autograd.Function is
    used via its static apply() and is never instantiated.)
    """

    @staticmethod
    def forward(ctx, x, weight):
        # stash the weight for the backward pass; output equals the input
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        weight, = ctx.saved_tensors

        # scale the gradient; no gradient flows to the weight itself
        grad_input = grad_output * weight

        return grad_input, None
# API
def relegance_gradient_weighting(x, weight):
    """
    Args:
        x (torch.tensor): input tensor
        weight (torch.tensor or None): weight tensor for gradients of x; if None, no gradient weighting will be applied in backward pass

    Returns:
        torch.tensor: the unmodified input tensor x (or a clone of it when
        gradient weighting is active)
    """
    return x if weight is None else GradWeight.apply(x, weight)
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
    """ creates parameters for mapping back output domain relevance to input domain

    Args:
        model (nn.Module or autograd.Function): fully convolutional model
        num_channels (int): number of input channels to model
        num_dims (int): number of input dimensions of model (without channel and batch dimension)
        width_hint(int or None): optional hint at maximal width of receptive field
        stride_hint(int or None): optional hint at maximal stride

    Returns:
        dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution

    Raises:
        RuntimeError: if estimation of parameters fails due to exceeded compute budget
    """
    max_width = int(100000 / (10 ** num_dims))
    did_it = False
    try:
        receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
        did_it = True
    except ValueError:
        # first attempt exceeded the width budget; retry once with a 10x budget
        # (previously a bare `except:` which also swallowed KeyboardInterrupt)
        max_width *= 10

    try:
        if not did_it: receptive_field_size, strides, left_paddings = inspect_conv_model(model, num_channels, num_dims, max_width=max_width, width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except ValueError:
        # give up if the enlarged budget is still insufficient
        raise RuntimeError("could not determine parameters within given compute budget")

    partition_kernel = create_partition_kernel(receptive_field_size, strides)
    # one kernel channel per model input channel
    partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)

    tconv_parameters = {
        'kernel': partition_kernel,
        'receptive_field_shape': receptive_field_size,
        'stride': strides,
        'left_padding': left_paddings,
        'num_dims': num_dims
    }

    return tconv_parameters
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
    """ maps output-domain relevance to input-domain relevance via transpose convolution

    Args:
        od_relevance (torch.tensor): output-domain relevance
        tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel

    Returns:
        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
                      Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
                      convenient way to adjust the output to the correct size.

    Raises:
        ValueError: if number of dimensions is not supported
    """
    kernel = tconv_parameters['kernel'].to(od_relevance.device)
    rf_shape = tconv_parameters['receptive_field_shape']
    stride = tconv_parameters['stride']
    left_padding = tconv_parameters['left_padding']

    num_dims = len(kernel.shape) - 2

    # repeat boundary values
    od_padding = [rf_shape[i//2] // stride[i//2] + 1 for i in range(2 * num_dims)]
    padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
    od_padding = od_padding[::2]

    # apply mapping and left trimming
    if num_dims == 1:
        id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
    elif num_dims == 2:
        id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:]
    elif num_dims == 3:
        # bug fix: the 3-d branch previously called F.conv_transpose2d
        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1]:, left_padding[2] + stride[2] * od_padding[2] :]
    else:
        raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')

    return id_relevance
def relegance_resize_relevance_to_input_size(reference_input, relevance):
    """ adjusts size of relevance tensor to reference input size

    Args:
        reference_input (torch.tensor): discriminator input tensor for reference
        relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input

    Returns:
        torch.tensor: relevance trimmed/padded to the shape of reference_input

    Raises:
        ValueError: if number of dimensions is not supported
    """
    resized = torch.zeros_like(reference_input)
    num_dims = reference_input.dim() - 2
    with torch.no_grad():
        if num_dims not in {1, 2, 3}:
            raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')
        # trim each spatial dimension to the smaller of the two extents
        trim = tuple(slice(min(reference_input.size(d), relevance.size(d))) for d in range(-num_dims, 0))
        resized[:] = relevance[(Ellipsis,) + trim]
    return resized

View File

@@ -0,0 +1,6 @@
from .gru_sparsifier import GRUSparsifier
from .conv1d_sparsifier import Conv1dSparsifier
from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
from .linear_sparsifier import LinearSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step
from .utils import mark_for_sparsification, create_sparsifier

View File

@@ -0,0 +1,58 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
class BaseSparsifier:
    """Common scheduling logic for sparsifiers.

    Subclasses implement sparsify(alpha, verbose) and drive it by calling
    step() once per training step.
    """

    def __init__(self, task_list, start, stop, interval, exponent=3):
        # schedule parameters
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # number of step() calls so far
        self.step_counter = 0

    def step(self, verbose=False):
        """Advance the schedule by one step and sparsify when due."""
        self.step_counter += 1

        # before `start`: do nothing
        if self.step_counter < self.start:
            return

        if self.step_counter < self.stop:
            # between start and stop only act every `interval`-th step
            if self.step_counter % self.interval:
                return
            # interpolation factor decays from 1 towards 0
            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            # past `stop`: sparsify at full target density on every call
            alpha = 0

        self.sparsify(alpha, verbose=verbose)

View File

@@ -0,0 +1,123 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
debug=True
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    Keeps the highest-energy (m1, n1) sub-blocks so that roughly `density`
    of all blocks survive and zeroes the rest.

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : float
        target density in [0, 1]
    block_size : [int, int]
        block size dimensions
    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
    return_mask : bool
        If true, the element-wise mask is returned alongside the masked matrix
    """
    m, n = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # set the diagonal aside if it must survive sparsification
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # per-block energy: sum of squares within each (m1, n1) tile
    tiles = torch.reshape(matrix, (m // m1, m1, n // n1, n1)) ** 2
    block_energies = torch.sum(torch.sum(tiles, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # energy threshold below which blocks are zeroed
    if number_of_survivors == 0:
        threshold = 0
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]

    # expand the block-level mask to element resolution
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(torch.repeat_interleave(mask, m1, dim=0), n1, dim=1)

    masked_matrix = mask * matrix + to_spare

    return (masked_matrix, mask) if return_mask else masked_matrix
def calculate_gru_flops_per_step(gru, sparsification_dict=None, drop_input=False):
    """Estimate the FLOPs of a single GRU time step.

    Args:
        gru: torch.nn.GRU instance
        sparsification_dict: optional dict mapping sub-matrix names
            ('W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz') to sequences whose
            first entry is the target density; missing entries count as dense.
        drop_input (bool): if True, input matrix-vector products are excluded.

    Returns:
        float: estimated flops per step
    """
    # avoid the previous shared mutable default argument (dict())
    if sparsification_dict is None:
        sparsification_dict = dict()
    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    # mean density across the three input-side sub-matrices
    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    # mean density across the three recurrent sub-matrices
    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops

View File

@@ -0,0 +1,133 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class Conv1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Conv1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv1d, params), where conv1d is an instance
            of torch.nn.Conv1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to Conv1dSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.Conv1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...         sparsifier.step()
        """
        # bug fix: forward the caller's exponent instead of the hard-coded 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """
        with torch.no_grad():
            for conv, params in self.task_list:
                # weight-normalized layers keep their direction in weight_v
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                out_channels, in_channels, kernel_size = weight.shape
                # flatten to (out_channels, kernel_size * in_channels) for block sparsification
                w = weight.permute(0, 2, 1).flatten(1)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(out_channels, kernel_size, in_channels).permute(0, 2, 1)
                weight[:] = w

                if self.last_mask is not None:
                    # a previously-zeroed block coming back indicates resurrection
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
    # quick smoke test: run the sparsification schedule on a small conv layer
    print("Testing sparsifier")

    import torch
    layer = torch.nn.Conv1d(8, 16, 8)
    schedule = Conv1dSparsifier([(layer, (0.2, [8, 4]))], 0, 100, 5)

    for _ in range(100):
        schedule.step(verbose=True)

    print(layer.weight)

View File

@@ -0,0 +1,134 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class ConvTranspose1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv, params), where conv is an instance
            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the
            sub-blocks to which sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop, sparsification
            will be carried out after every call to step()

        exponent : float
            Interpolation exponent for the sparsification interval. In step i sparsification will
            be carried out with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the user-supplied exponent instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # mask from the previous sparsification step; used to detect weight resurrection
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None
        """
        with torch.no_grad():
            for conv, params in self.task_list:
                # support weight-normalized layers: operate on the direction tensor
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # ConvTranspose1d weight has shape (in, out, kernel); flatten to a 2d matrix
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                if self.last_mask is not None:
                    # a previously-zeroed entry becoming non-zero indicates weight resurrection
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
    # smoke test: drive a ConvTranspose1d layer through 100 sparsification steps
    print("Testing sparsifier")
    import torch

    layer = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    task = (layer, (0.2, [8, 4]))
    sparsifier = ConvTranspose1dSparsifier([task], 0, 100, 5)
    for _ in range(100):
        sparsifier.step(verbose=True)
    print(layer.weight)

View File

@@ -0,0 +1,178 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class GRUSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to
            which sparsification is applied and keep_diagonal is a bool variable indicating whether
            the diagonal should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop, sparsification
            will be carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for the sparsification interval. In step i sparsification will
            be carried out with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...     'W_ir' : (0.5, [2, 2], False),
        ...     'W_iz' : (0.6, [2, 2], False),
        ...     'W_in' : (0.7, [2, 2], False),
        ...     'W_hr' : (0.1, [4, 4], True),
        ...     'W_hz' : (0.2, [4, 4], True),
        ...     'W_hn' : (0.3, [4, 4], True),
        ... }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the user-supplied exponent instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # per-gate masks from the previous step; used to detect weight resurrection
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None
        """
        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: weight_ih_l0 stacks the r, z, n gate matrices row-wise
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        # support weight-normalized layers: operate on the direction tensor
                        if hasattr(gru, 'weight_ih_l0_v'):
                            weight = gru.weight_ih_l0_v
                        else:
                            weight = gru.weight_ih_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        weight[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, :],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_ih_l0_v")

                        self.last_masks[key] = new_mask

                # recurrent weights: weight_hh_l0 stacks the r, z, n gate matrices row-wise
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        if hasattr(gru, 'weight_hh_l0_v'):
                            weight = gru.weight_hh_l0_v
                        else:
                            weight = gru.weight_hh_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        weight[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, :],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            # bug fix: this branch was gated on `and True` instead of
                            # `and debug` (inconsistent with the input-weight branch)
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_hh_l0_v")

                        self.last_masks[key] = new_mask
if __name__ == "__main__":
    # smoke test: sparsify all six GRU gate matrices over 100 steps
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    gate_params = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }
    sparsifier = GRUSparsifier([(gru, gate_params)], 0, 100, 10)
    for _ in range(100):
        sparsifier.step(verbose=True)
    print(gru.weight_hh_l0)

View File

@@ -0,0 +1,128 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (linear, params), where linear is an instance
            of torch.nn.Linear and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1] and [m, n] is the shape of the
            sub-blocks to which sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop, sparsification
            will be carried out after every call to step()

        exponent : float
            Interpolation exponent for the sparsification interval. In step i sparsification will
            be carried out with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (start - stop)) ** exponent

        Example:
        --------
        >>> import torch
        >>> linear = torch.nn.Linear(8, 16)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the user-supplied exponent instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # mask from the previous sparsification step; used to detect weight resurrection
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None
        """
        with torch.no_grad():
            for linear, params in self.task_list:
                # support weight-normalized layers: operate on the direction tensor
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                if self.last_mask is not None:
                    # bug fix: `debug` is now imported from .common (previously this line
                    # raised a NameError when reached)
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        # bug fix: message previously said "conv.weight" (copy-paste)
                        print("weight resurrection in linear.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")
if __name__ == "__main__":
    # smoke test: drive a Linear layer through 100 sparsification steps
    print("Testing sparsifier")
    import torch

    layer = torch.nn.Linear(8, 16)
    task = (layer, (0.2, [4, 2]))
    sparsifier = LinearSparsifier([task], 0, 100, 5)
    for _ in range(100):
        sparsifier.step(verbose=True)
    print(layer.weight)

View File

@@ -0,0 +1,64 @@
import torch
from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
def mark_for_sparsification(module, params):
    """Tag *module* for later sparsification and return it.

    Sets the attributes `sparsify` (flag) and `sparsification_params`
    (the per-layer parameter tuple/dict), which create_sparsifier()
    looks for when building sparsifiers.
    """
    module.sparsify = True
    module.sparsification_params = params
    return module
def create_sparsifier(module, start, stop, interval):
    """Build sparsifiers for all sub-modules marked via mark_for_sparsification.

    Returns a function sparsify(verbose=False) that advances every collected
    sparsifier by one step; call it once per training step.
    """
    # ordered dispatch table: module type -> matching sparsifier class
    dispatch = [
        (torch.nn.GRU, GRUSparsifier),
        (torch.nn.Linear, LinearSparsifier),
        (torch.nn.Conv1d, Conv1dSparsifier),
        (torch.nn.ConvTranspose1d, ConvTranspose1dSparsifier),
    ]

    sparsifier_list = []
    for m in module.modules():
        if not hasattr(m, 'sparsify'):
            continue
        for layer_type, sparsifier_cls in dispatch:
            if isinstance(m, layer_type):
                sparsifier_list.append(
                    sparsifier_cls([(m, m.sparsification_params)], start, stop, interval)
                )
                break
        else:
            print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")

    def sparsify(verbose=False):
        for sparsifier in sparsifier_list:
            sparsifier.step(verbose)

    return sparsify
def count_parameters(model, verbose=False):
    """Return the total number of parameters of *model*.

    Parameters
    ----------
    model : torch.nn.Module
        model whose parameters are counted
    verbose : bool
        if True, the per-parameter counts are printed

    Returns
    -------
    int
        total number of parameter entries
    """
    total = 0
    for name, p in model.named_parameters():
        # p.numel() counts entries directly instead of allocating a ones tensor
        # and summing it (original: torch.ones_like(p).sum().item())
        count = p.numel()
        if verbose:
            print(f"{name}: {count} parameters")
        total += count

    return total
def estimate_nonzero_parameters(module):
    """Estimate the number of non-zero weight parameters of *module* after sparsification.

    Modules not marked via mark_for_sparsification contribute 0.
    Bias parameters are not included in the estimate.

    Parameters
    ----------
    module : torch.nn.Module
        module marked (or not) for sparsification

    Returns
    -------
    float
        estimated non-zero weight count (density * weight count)

    Raises
    ------
    ValueError
        if the module is marked for sparsification but its type is unsupported
    """
    # bug fix: the original computed the estimate but never returned it, and mixed
    # zero-counts (Conv/GRU branches) with non-zero counts (Linear branch); all
    # branches now consistently return the estimated NON-zero count.
    if not hasattr(module, 'sparsify'):
        return 0

    params = module.sparsification_params
    if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
        # params = (density, [m, n]); density is the fraction of weights kept
        num_nonzero_parameters = module.weight.numel() * params[0]
    elif isinstance(module, torch.nn.GRU):
        # params maps gate names to (density, block_size, keep_diagonal)
        num_nonzero_parameters = module.input_size * module.hidden_size * (
            params['W_ir'][0] + params['W_iz'][0] + params['W_in'][0])
        num_nonzero_parameters += module.hidden_size * module.hidden_size * (
            params['W_hr'][0] + params['W_hz'][0] + params['W_hn'][0])
    elif isinstance(module, torch.nn.Linear):
        num_nonzero_parameters = module.in_features * module.out_features * params[0]
    else:
        raise ValueError(f'unknown sparsification method for module of type {type(module)}')

    return num_nonzero_parameters

View File

@@ -0,0 +1 @@
torch

View File

@@ -0,0 +1,48 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
#!/usr/bin/env python

import os

from setuptools import setup

# read install requirements from the requirements.txt next to this file
lib_folder = os.path.dirname(os.path.realpath(__file__))
requirements_path = os.path.join(lib_folder, 'requirements.txt')
with open(requirements_path, 'r') as f:
    install_requires = list(f.read().splitlines())
print(install_requires)

setup(
    name='dnntools',
    version='1.0',
    author='Jan Buethe',
    author_email='jbuethe@amazon.de',
    description='Non-Standard tools for deep neural network training with PyTorch',
    packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
    install_requires=install_requires,
)

View File

@@ -0,0 +1,54 @@
# Framewise Auto-Regressive GAN (FARGAN)
Implementation of FARGAN, a low-complexity neural vocoder. Pre-trained models
are provided as C code in the dnn/ directory with the corresponding model in
dnn/models/ directory (name starts with fargan_). If you don't want to train
a new FARGAN model, you can skip straight to the Inference section.
## Data preparation
For data preparation you need to build Opus as detailed in the top-level README.
You will need to use the --enable-deep-plc configure option.
The build will produce an executable named "dump_data".
To prepare the training data, run:
```
./dump_data -train in_speech.pcm out_features.f32 out_speech.pcm
```
Where the in_speech.pcm speech file is a raw 16-bit PCM file sampled at 16 kHz.
The speech data used for training the model can be found at:
https://media.xiph.org/lpcnet/speech/tts_speech_negative_16k.sw
## Training
To perform pre-training, run the following command:
```
python ./train_fargan.py out_features.f32 out_speech.pcm output_dir --epochs 400 --batch-size 4096 --lr 0.002 --cuda-visible-devices 0
```
Once pre-training is complete, run adversarial training using:
```
python adv_train_fargan.py out_features.f32 out_speech.pcm output_dir --lr 0.000002 --reg-weight 5 --batch-size 160 --cuda-visible-devices 0 --initial-checkpoint output_dir/checkpoints/fargan_400.pth
```
The final model will be in output_dir/checkpoints/fargan_adv_50.pth.
The model can optionally be converted to C using:
```
python dump_fargan_weights.py output_dir/checkpoints/fargan_adv_50.pth fargan_c_dir
```
which will create a fargan_data.c and a fargan_data.h file in the fargan_c_dir directory.
Copy these files to the opus/dnn/ directory (replacing the existing ones) and recompile Opus.
## Inference
To run the inference, start by generating the features from the audio using:
```
./fargan_demo -features test_speech.pcm test_features.f32
```
Synthesis can be achieved either using the PyTorch code or the C code.
To synthesize from PyTorch, run:
```
python test_fargan.py output_dir/checkpoints/fargan_adv_50.pth test_features.f32 output_speech.pcm
```
To synthesize from the C code, run:
```
./fargan_demo -fargan-synthesis test_features.f32 output_speech.pcm
```

View File

@@ -0,0 +1,278 @@
import os
import argparse
import random
import numpy as np
import sys
import math as m
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
from stft_loss import *
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../osce/"))
import models as osce_models
def fmap_loss(scores_real, scores_gen):
    """Feature-matching loss between real and generated discriminator activations.

    scores_real / scores_gen are lists (one entry per discriminator) of lists of
    feature maps; the last entry of each inner list is the discriminator output
    and is excluded from the loss. Real features are detached so the loss only
    drives the generator.
    """
    num_discs = len(scores_real)
    loss_feat = 0
    for real_feats, gen_feats in zip(scores_real, scores_gen):
        num_layers = len(gen_feats) - 1
        weight = 4 / num_discs / num_layers
        for gen_layer, real_layer in zip(gen_feats[:-1], real_feats):
            loss_feat = loss_feat + weight * F.l1_loss(gen_layer, real_layer.detach())
    return loss_feat
# ---- command-line interface ----
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
training_group.add_argument('--lr', type=float, help='learning rate, default: 5e-4', default=5e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 50', default=50)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 60', default=60)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 0.0', default=0.0)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--reg-weight', type=float, help='regression loss weight, default: 1.0', default=1.0)
training_group.add_argument('--fmap-weight', type=float, help='feature matchin loss weight, default: 1.0', default=1.)

args = parser.parse_args()

# restrict visible GPUs before any CUDA initialization
if args.cuda_visible_devices != None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size

# record the training configuration in the checkpoint for reproducibility
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

#discriminator
disc_name = 'fdmresdisc'
disc = osce_models.model_dict[disc_name](
    architecture='free',
    design='f_down',
    fft_sizes_16k=[2**n for n in range(6, 12)],
    freq_roi=[0, 7400],
    max_channels=256,
    noise_gain=0.0
)

# optionally resume generator weights from a pre-training checkpoint
# (note: `checkpoint` is replaced by the loaded dict here, so the config
# recorded above is that of the loaded checkpoint from this point on)
if type(args.initial_checkpoint) != type(None):
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()

dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
optimizer_disc = torch.optim.AdamW([p for p in disc.parameters() if p.requires_grad], lr=lr, betas=adam_betas, eps=adam_eps)

# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
scheduler_disc = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_disc, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

spect_loss = MultiResolutionSTFTLoss(device).to(device)

# generator is frozen initially; unfrozen after 400 batches of epoch 1 (see training loop)
for param in model.parameters():
    param.requires_grad = False

batch_count = 0
if __name__ == '__main__':
    model.to(device)
    disc.to(device)

    for epoch in range(1, epochs + 1):
        # exponentially-smoothed discriminator score statistics (mean/std for real and fake),
        # used only to estimate the generator's "winning chance" diagnostic below
        m_r = 0
        m_f = 0
        s_r = 1
        s_f = 1

        # running loss accumulators for the progress bar / checkpoint
        running_cont_loss = 0
        running_disc_loss = 0
        running_gen_loss = 0
        running_fmap_loss = 0
        running_reg_loss = 0
        running_wc = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                # warm-up: unfreeze the generator after 400 batches of the first
                # epoch, keeping the conditioning net and gain dense layer frozen
                if epoch == 1 and i == 400:
                    for param in model.parameters():
                        param.requires_grad = True
                    for param in model.cond_net.parameters():
                        param.requires_grad = False
                    for param in model.sig_net.cond_gain_dense.parameters():
                        param.requires_grad = False
                optimizer.zero_grad()
                features = features.to(device)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
                # truncate to the configured sequence length (+4 frames of feature context)
                if True:
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:,:sequence_length+4,:]
                    periods = periods[:,:sequence_length+4]
                else:
                    target=target[::2, :]
                    #lpc=lpc[::2,:]
                    features=features[::2,:]
                    periods=periods[::2,:]
                target = target.to(device)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

                # seed the auto-regressive generator with 2 frames of ground truth
                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                output, _ = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                output = torch.cat([pre, output], -1)

                # discriminator update (generator output detached)
                scores_gen = disc(output.detach().unsqueeze(1))
                scores_real = disc(target.unsqueeze(1))

                # LSGAN-style discriminator objective: fake -> 0, real -> 1
                disc_loss = 0
                for scale in scores_gen:
                    disc_loss += ((scale[-1]) ** 2).mean()
                    m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
                for scale in scores_real:
                    disc_loss += ((1 - scale[-1]) ** 2).mean()
                    m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
                disc_loss = 0.5 * disc_loss / len(scores_gen)
                # probability that a fake score exceeds a real score, assuming Gaussians
                winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
                running_wc += winning_chance
                disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()

                # model update
                scores_gen = disc(output.unsqueeze(1))
                # NOTE(review): scores_real below still comes from the pre-update
                # discriminator pass above, since this re-computation is disabled
                if False: # todo: check whether that makes a difference
                    with torch.no_grad():
                        scores_real = disc(target.unsqueeze(1))

                # continuity loss on the first 80 generated samples (weighted 0 in reg_loss)
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
                specc_loss = spect_loss(output, target.detach())
                reg_loss = (.00*cont_loss + specc_loss)

                loss_gen = 0
                for scale in scores_gen:
                    loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)

                feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)

                reg_weight = args.reg_weight# + 15./(1 + (batch_count/7600.))
                gen_loss = reg_weight * reg_loss + feat_loss + loss_gen

                model.zero_grad()
                gen_loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()
                scheduler_disc.step()

                running_cont_loss += cont_loss.detach().cpu().item()
                running_gen_loss += loss_gen.detach().cpu().item()
                running_disc_loss += disc_loss.detach().cpu().item()
                running_fmap_loss += feat_loss.detach().cpu().item()
                running_reg_loss += reg_loss.detach().cpu().item()

                tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   reg_weight=f"{reg_weight:8.5f}",
                                   gen_loss=f"{running_gen_loss/(i+1):8.5f}",
                                   disc_loss=f"{running_disc_loss/(i+1):8.5f}",
                                   fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
                                   reg_loss=f"{running_reg_loss/(i+1):8.5f}",
                                   wc = f"{running_wc/(i+1):8.5f}",
                                   )
                batch_count = batch_count + 1

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        # NOTE(review): 'disc_sate_dict' looks like a typo of 'disc_state_dict' —
        # check any checkpoint loaders before renaming this key
        checkpoint['disc_sate_dict'] = disc.state_dict()
        checkpoint['loss'] = {
            'cont': running_cont_loss / len(dataloader),
            'gen': running_gen_loss / len(dataloader),
            'disc': running_disc_loss / len(dataloader),
            'fmap': running_fmap_loss / len(dataloader),
            'reg': running_reg_loss / len(dataloader)
        }
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)

View File

@@ -0,0 +1,61 @@
import torch
import numpy as np
import fargan
class FARGANDataset(torch.utils.data.Dataset):
    """Dataset of aligned (features, pitch periods, PCM signal, LPC) training windows.

    Memory-maps a raw int16 PCM file and a float32 feature file and exposes
    overlapping training windows via numpy stride tricks; nothing is copied
    until __getitem__.
    """
    def __init__(self,
                feature_file,
                signal_file,
                frame_size=160,
                sequence_length=15,
                lookahead=1,
                nb_used_features=20,
                nb_features=36):

        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        # samples per training sequence
        pcm_chunk_size = self.frame_size*self.sequence_length

        # memmap so large files are not loaded into RAM
        self.data = np.memmap(signal_file, dtype='int16', mode='r')
        #self.data = self.data[1::2]
        self.nb_sequences = len(self.data)//(pcm_chunk_size)-4
        # drop leading frames to align signal with features (4-frame history minus lookahead)
        self.data = self.data[(4-self.lookahead)*self.frame_size:]
        self.data = self.data[:self.nb_sequences*pcm_chunk_size]

        #self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
        # overlapping windows: each index yields 2 sequences worth of samples,
        # advancing by one sequence per index (rows share memory)
        sizeof = self.data.strides[-1]
        self.data = np.lib.stride_tricks.as_strided(self.data, shape=(self.nb_sequences, pcm_chunk_size*2),
                                           strides=(pcm_chunk_size*sizeof, sizeof))

        self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
        sizeof = self.features.strides[-1]
        # matching overlapping feature windows, with 4 extra frames of context
        self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
                                           strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
        #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
        # pitch period in samples derived from the pitch feature, clipped to [32, 255]
        self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')

        # trailing feature columns hold the LPC coefficients
        self.lpc = self.features[:, :, self.nb_used_features:]
        self.features = self.features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    def __len__(self):
        # number of training windows
        return self.nb_sequences

    def __getitem__(self, index):
        features = self.features[index, :, :].copy()
        # shift the LPC window by the lookahead relative to the features
        if self.lookahead != 0:
            lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
        else:
            lpc = self.lpc[index, 4:, :].copy()
        # normalize int16 PCM to [-1, 1)
        data = self.data[index, :].copy().astype(np.float32) / 2**15
        periods = self.periods[index, :].copy()
        #lpc = lpc*(self.gamma**np.arange(1,17))
        #lpc=lpc[None,:,:]
        #lpc = fargan.interp_lpc(lpc, 4)
        #lpc=lpc[0,:,:]

        return features, periods, data, lpc

View File

@@ -0,0 +1,112 @@
import os
import sys
import argparse
import torch
from torch import nn
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import fargan
#from models import model_dict
# layers exported WITHOUT quantization when --quantize is given
unquantized = [ 'cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out' ]

# alternative, larger exclusion list (currently unused)
unquantized2 = [
    'cond_net.pembed',
    'cond_net.fdense1',
    'cond_net.fconv1',
    'cond_net.fconv2',
    'cont_net.0',
    'sig_net.cond_gain_dense',
    'sig_net.fwc0.conv',
    'sig_net.fwc0.glu.gate',
    'sig_net.dense1_glu.gate',
    'sig_net.gru1_glu.gate',
    'sig_net.gru2_glu.gate',
    'sig_net.gru3_glu.gate',
    'sig_net.skip_glu.gate',
    'sig_net.skip_dense',
    'sig_net.sig_dense_out',
    'sig_net.gain_dense_out'
]

description=f"""
This is an unsafe dumping script for FARGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layer
and will fail to export any other weights.

Furthermore, the quanitze option relies on the following explicit list of layers to be excluded:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)

parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

if __name__ == "__main__":
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    saved_gen= torch.load(args.weightfile, map_location='cpu')
    # model_args/model_kwargs are overwritten so checkpoints always load with these settings
    saved_gen['model_args'] = ()
    saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}

    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    model.load_state_dict(saved_gen['state_dict'], strict=False)

    def _remove_weight_norm(m):
        # fold weight norm (weight_v / weight_g) back into a plain weight before export
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return

    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model=args.quantize
    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): uses wexchange.c_export, but only wexchange.torch is imported above —
    # presumably importing wexchange.torch pulls in c_export as a side effect; confirm.
    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name, add_typedef=True)

    for name, module in model.named_modules():
        if quantize_model:
            # quantize everything except the explicitly excluded layers
            quantize=name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize=False
            scale=1/128
        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
            #wexchange.torch.dump_torch_embedding_weights(writer, module)
        else:
            print(f"Ignoring layer {name}...")

    writer.close()

View File

@@ -0,0 +1,346 @@
import os
import sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
from rc import lpc2rc, rc2lpc
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../dnntools"))
from dnntools.quantization import soft_quant
Fs = 16000
fid_dict = {}
def dump_signal(x, filename):
    """Debug helper: append the tensor contents to a raw float32 file.

    The unconditional `return` below disables dumping; remove it to
    re-enable the debug output.  One file handle is cached per filename
    so repeated calls append to the same file.
    """
    return
    # NOTE(review): dead code below until the early return is removed;
    # the file is opened in text mode ("w") although tofile() writes
    # binary data — confirm before re-enabling.
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)
def sig_l1(y_true, y_pred):
    """Relative L1 error: mean |y_true - y_pred| normalized by mean |y_true|."""
    err = torch.mean(torch.abs(y_true - y_pred))
    ref = torch.mean(torch.abs(y_true))
    return err / ref
def sig_loss(y_true, y_pred):
    """Cosine-distance loss: mean of (1 - cosine similarity) over the last dim."""
    eps = 1e-15
    unit_true = y_true / (eps + torch.norm(y_true, dim=-1, p=2, keepdim=True))
    unit_pred = y_pred / (eps + torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    cos_sim = torch.sum(unit_pred * unit_true, dim=-1)
    return torch.mean(1. - cos_sim)
def interp_lpc(lpc, factor):
    """Upsample LPC coefficients in time by `factor` via interpolation.

    The LPCs are converted to reflection coefficients, mapped through
    atanh to an unbounded domain where linear interpolation is safe
    (guaranteeing a stable filter after tanh), interpolated between
    consecutive frames, and converted back to LPCs.

    Input shape (batch, nb_frames, order); output
    (batch, nb_frames*factor, order).
    """
    # LPC -> reflection coefficients -> unbounded "LSP-like" domain.
    lsp = torch.atanh(lpc2rc(lpc))
    shape = lsp.shape
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    # Interpolate between frame centers; the half-sample offset depends on
    # whether factor is even or odd.
    for k in range(factor):
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # Pad the edges by replicating the first/last interpolated values.
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    return rc2lpc(torch.tanh(interp_lsp))
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Run the LPC analysis (FIR) filter A(z) over x, subframe by subframe.

    The filter is applied as a batched matrix multiply: each frame's
    coefficients are expanded into a lower-triangular Toeplitz matrix so
    all samples of a subframe are filtered in one bmm.  A rolling buffer
    of subframe_size+16 samples provides the 16 samples of filter memory.

    NOTE(review): `gamma` is currently unused here — the bandwidth
    expansion it suggests is expected to be applied by the caller.
    """
    device = x.device
    batch_size = lpc.size(0)
    nb_frames = lpc.shape[1]
    # Rolling signal buffer: 16 samples of history + current subframe.
    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)
    # Build A(z) = [1, lpc...] zero-padded so the Toeplitz matrix covers a
    # whole subframe (matrix size = subframe_size + 16).
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)
    for n in range(nb_frames):
        for k in range(nb_subframes):
            # Shift the new subframe into the buffer, filter, keep the
            # residual for the current subframe only.
            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)
    return out
# weight initialization and clipping
def init_weights(module):
if isinstance(module, nn.GRU):
for p in module.named_parameters():
if p[0].startswith('weight_hh_'):
nn.init.orthogonal_(p[1])
def gen_phase_embedding(periods, frame_size):
    """Build per-sample (cos, sin) phase embeddings from per-frame pitch periods.

    The phase advances at rate w0 = 2*pi/period within each frame and
    accumulates across frames; a random phase offset is drawn per batch
    element.  Returns two tensors of shape (batch, nb_frames*frame_size).
    """
    device = periods.device
    batch_size = periods.size(0)
    nb_frames = periods.size(1)
    w0 = 2*torch.pi/periods
    # Random start phase, then the previous frames' frequencies shifted by one.
    rand_start = 2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size
    w0_shift = torch.cat([rand_start, w0[:,:-1]], 1)
    # Phase accumulated at each frame boundary.
    cum_phase = frame_size*torch.cumsum(w0_shift, 1)
    # Linear phase ramp inside each frame at the frame's own frequency.
    ramp = torch.arange(frame_size, device=device)
    fine_phase = w0[:,:,None]*torch.broadcast_to(ramp, (batch_size, nb_frames, frame_size))
    embed = cum_phase[:,:,None] + fine_phase
    embed = embed.reshape(batch_size, -1)
    return torch.cos(embed), torch.sin(embed)
class GLU(nn.Module):
    """Gated linear unit: out = x * sigmoid(W x), with a weight-normalized,
    bias-free gate and orthogonal weight init (fixed seed for
    reproducible exports)."""

    def __init__(self, feat_size, softquant=False):
        super(GLU, self).__init__()
        torch.manual_seed(5)
        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
        if softquant:
            self.gate = soft_quant(self.gate)
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every weight-bearing submodule.
        targets = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for m in self.modules():
            if isinstance(m, targets):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        return x * torch.sigmoid(self.gate(x))
class FWConv(nn.Module):
    """Frame-wise "convolution": a linear layer over the concatenation of the
    previous input(s) (carried in `state`) and the current input, followed by
    tanh and a GLU.  Equivalent to a kernel_size-tap conv run one step at a
    time."""

    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
        super(FWConv, self).__init__()
        torch.manual_seed(5)
        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size, softquant=softquant)
        if softquant:
            self.conv = soft_quant(self.conv)
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every weight-bearing submodule.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        # Sliding window over [state, x]; the new state drops the oldest input.
        window = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(window)))
        return out, window[:, self.in_size:]
def n(x):
    """Simulated-quantization noise: add uniform dither of +/- half an LSB
    (1/127 full scale) and clamp to [-1, 1]."""
    dither = (1./127.) * (torch.rand_like(x) - .5)
    return torch.clamp(x + dither, min=-1., max=1.)
class FARGANCond(nn.Module):
    """Frame-rate conditioning network.

    Maps acoustic features plus a pitch-period embedding to one 80-dim
    conditioning vector per subframe (4 subframes per frame, hence the
    80*4 output width).
    """
    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
        super(FARGANCond, self).__init__()
        self.feature_dim = feature_dim
        # NOTE(review): cond_size is stored but the layer widths below are
        # hard-coded; confirm before relying on it.
        self.cond_size = cond_size
        # Pitch-period embedding; periods are offset by 32 in forward().
        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)
        if softquant:
            self.fconv1 = soft_quant(self.fconv1)
            self.fdense2 = soft_quant(self.fdense2)
        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        """Return conditioning of shape (batch, frames-4, 320).

        Two frames are dropped up front and the 'valid' 3-tap conv trims
        two more, so the output is 4 frames shorter than the input.
        """
        features = features[:,2:,:]
        period = period[:,2:]
        p = self.pembed(period-32)
        features = torch.cat((features, p), -1)
        tmp = torch.tanh(self.fdense1(features))
        # Conv1d expects (batch, channels, time).
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fconv1(tmp))
        tmp = tmp.permute(0, 2, 1)
        tmp = torch.tanh(self.fdense2(tmp))
        return tmp
class FARGANSub(nn.Module):
    """Sample-rate (subframe) synthesis network.

    Generates one subframe of excitation at a time, conditioned on the
    frame-rate conditioning vector, the past excitation memory and the
    pitch period.  Three stacked GRU cells with GLU gating plus a
    feed-forward skip path produce the output, which is rescaled by a
    gain predicted from the conditioning.
    """
    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
        super(FARGANSub, self).__init__()
        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        # Predicts log-gain from the 80-dim conditioning vector.
        self.cond_gain_dense = nn.Linear(80, 1)
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
        self.gru1_glu = GLU(160, softquant=softquant)
        self.gru2_glu = GLU(128, softquant=softquant)
        self.gru3_glu = GLU(128, softquant=softquant)
        self.skip_glu = GLU(128, softquant=softquant)
        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        # Per-layer pitch-prediction gains (one per GRU input + skip path).
        self.gain_dense_out = nn.Linear(192, 4)
        if softquant:
            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
            self.skip_dense = soft_quant(self.skip_dense)
            self.sig_dense_out = soft_quant(self.sig_dense_out)
        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Synthesize one subframe.

        Returns (sig_out, exc_mem, prev_pred, new_states).  The incoming
        `gain` argument is only dumped for debugging; the gain actually
        used is predicted from `cond`.
        """
        device = exc_mem.device
        # n() adds simulated-quantization dither throughout.
        cond = n(cond)
        dump_signal(gain, 'gain0.f32')
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')
        # Gather the pitch-delayed excitation (subframe_size+4 samples,
        # centered with 2 extra samples on each side).  If the period is
        # shorter than the lookback, wrap by an extra period.
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        mask = idx >= 256
        idx = idx - mask*period[:,None]
        pred = torch.gather(exc_mem, 1, idx)
        # Normalize by the predicted gain before feeding the network.
        pred = n(pred/(1e-5+gain))
        prev = exc_mem[:,-self.subframe_size:]
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')
        tmp = torch.cat((cond, pred, prev), 1)
        # Centered pitch prediction (drop the 2-sample guard on each side).
        fpitch = pred[:,2:-2]
        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        # Sigmoid gains that scale the pitch prediction fed to each layer.
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))
        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)
        # Concatenate all layer outputs for the skip connection.
        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')
        dump_signal(pitch_gain, 'pgain.f32')
        # Undo the gain normalization on the output.
        sig_out = sig_out * gain
        # Shift the new subframe into the excitation memory.
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
class FARGAN(nn.Module):
    """Full FARGAN vocoder: frame-rate conditioning network plus the
    autoregressive subframe synthesis network."""
    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
        super(FARGAN, self).__init__()
        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size
        # NOTE(review): passthrough_size, has_gain and gamma are accepted
        # (likely for checkpoint kwargs compatibility) but not used here.
        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Synthesize `nb_frames` frames of signal.

        `pre` optionally provides ground-truth samples used for teacher
        forcing during the first nb_pre_frames frames.  Returns
        (signal, detached_states).

        NOTE(review): the incoming `states` argument is ignored — fresh
        zero states are always created below.
        """
        device = features.device
        batch_size = features.size(0)
        # Pitch-prediction history and excitation memory (256 samples each).
        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
        # Zero initial states for gru1/gru2/gru3 and the FWConv window.
        states = (
            torch.zeros(batch_size, 160, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
        )
        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        if pre is not None:
            exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
        start = 1 if nb_pre_frames>0 else 0
        for n in range(start, nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = n*self.frame_size + k*self.subframe_size
                # Features/periods are offset by 3 to align with the
                # conditioning network's frame trimming.
                pitch = period[:, 3+n]
                gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
                if n < nb_pre_frames:
                    # Teacher forcing: overwrite the output (and the
                    # excitation memory) with ground truth.
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)
        states = [s.detach() for s in states]
        return sig, states

View File

@@ -0,0 +1,46 @@
import torch
from torch import nn
import torch.nn.functional as F
import math
def toeplitz_from_filter(a):
    """Expand FIR coefficients into a lower-triangular Toeplitz matrix.

    For coefficients a[..., 0:L], the result T[..., i, j] equals
    a[..., i-j] for i >= j and 0 otherwise, so filtering a length-L
    signal becomes a matrix-vector product.
    """
    device = a.device
    L = a.size(-1)
    padded_shape = (*(a.shape[:-1]), L, L+1)
    mat_shape = (*(a.shape[:-1]), L, L)
    # Index map: entry (i, j) selects padded coefficient max(i-j+1, 0),
    # where padded index 0 is an explicit zero.
    pos = torch.arange(0, L, dtype=torch.int64, device=device)
    zero = torch.tensor(0, device=device)
    gather_idx = torch.maximum(pos[:, None] - pos[None, :] + 1, zero)
    padded = torch.cat([a[..., :1]*0, a], -1)[..., None, :]
    padded = torch.broadcast_to(padded, padded_shape)
    gather_idx = torch.broadcast_to(gather_idx, mat_shape)
    return torch.gather(padded, -1, gather_idx)
def filter_iir_response(a, N):
    """First N samples of the impulse response of the all-pole filter 1/A(z).

    `a` holds the denominator coefficients [1, a1, ..., a_{L-1}] with
    shape (batch, frames, L); the result has shape (batch, frames, N).
    """
    device = a.device
    L = a.size(-1)
    rev = a.flip(dims=(2,))
    resp = torch.zeros((*(a.shape[:-1]), N), device=device)
    resp[:, :, 0] = torch.ones((a.shape[:-1]), device=device)
    # Warm-up: fewer than L-1 past samples are available.
    for i in range(1, L):
        resp[:, :, i] = -torch.sum(rev[:, :, L-i-1:-1] * resp[:, :, :i], axis=-1)
    # Steady state: full recursion over the last L-1 outputs.
    for i in range(L, N):
        resp[:, :, i] = -torch.sum(rev[:, :, :-1] * resp[:, :, i-L+1:i], axis=-1)
    return resp
if __name__ == '__main__':
    # Quick sanity check: the Toeplitz matrix of the impulse response of
    # 1/A(z) should be (approximately) the inverse of that of A(z).
    coeffs = torch.tensor([[[1, -.9, 0.02], [1, -.8, .01]]])
    A = toeplitz_from_filter(coeffs)
    R = filter_iir_response(coeffs, 5)
    RA = toeplitz_from_filter(R)
    print(RA)

View File

@@ -0,0 +1,29 @@
import torch
def rc2lpc(rc):
    """Levinson step-up: reflection coefficients -> LPC coefficients.

    Works on the last dimension; all leading dimensions are batched.
    """
    order = rc.shape[-1]
    lpc = rc[..., 0:1]
    for i in range(1, order):
        k = rc[..., i:i+1]
        lpc = torch.cat([lpc + k * torch.flip(lpc, dims=(-1,)), k], -1)
    return lpc
def lpc2rc(lpc):
    """Levinson step-down: LPC coefficients -> reflection coefficients.

    Inverse of rc2lpc; works on the last dimension with batching.
    """
    order = lpc.shape[-1]
    rc = lpc[..., -1:]
    for _ in range(order - 1, 0, -1):
        k = lpc[..., -1:]
        lpc = lpc[..., :-1]
        lpc = (lpc - k * torch.flip(lpc, dims=(-1,))) / (1 - k * k)
        rc = torch.cat([lpc[..., -1:], rc], -1)
    return rc
if __name__ == "__main__":
    # Round-trip sanity check: rc -> lpc -> rc should recover the input.
    rc = torch.tensor([[.5, -.5, .6, -.6]])
    print(rc)
    lpc = rc2lpc(rc)
    print(lpc)
    rc2 = lpc2rc(lpc)
    print(rc2)

View File

@@ -0,0 +1,186 @@
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
import numpy as np
import torchaudio
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor (e.g. torch.hann_window(win_length)).

    Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
    """
    spec = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    # Clamp is needed to avoid nan/inf in downstream sqrt/log operations.
    return torch.clamp(torch.abs(spec), min=1e-7)
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss computed on square-root magnitudes:
    ||sqrt(y) - sqrt(x)||_1 / ||sqrt(y)||_1."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Compute the loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal.
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal.

        Returns:
            Tensor: Scalar spectral convergence loss.
        """
        x_root = torch.sqrt(x_mag)
        y_root = torch.sqrt(y_mag)
        return torch.norm(y_root - x_root, p=1) / torch.norm(y_root, p=1)
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Magnitude loss module.

    NOTE(review): despite the historical name, no logarithm is taken —
    the loss is a plain L1 between the two magnitude spectrograms.
    Earlier experiments (log/sqrt variants and a concordance-correlation
    term) were removed; the name is kept for checkpoint compatibility.
    """

    def __init__(self):
        """Initialize the magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x, y):
        """Return the mean absolute error between y and x.

        Args:
            x (Tensor): Magnitude spectrogram of predicted signal.
            y (Tensor): Magnitude spectrogram of groundtruth signal.

        Returns:
            Tensor: Scalar L1 magnitude loss.
        """
        return F.l1_loss(y, x)
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss: spectral convergence + L1 magnitude."""
    def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module.

        `window` is the name of a torch window function (e.g.
        "hann_window"); the window tensor is built once and moved to
        `device`.
        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length).to(device)
        # NOTE: attribute name keeps the historical "convergenge" typo for
        # checkpoint compatibility.
        self.spectral_convergenge_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: L1 magnitude loss value.
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi-resolution STFT loss.

    Averages the spectral-convergence term over several STFT
    resolutions.  The magnitude term of each STFTLoss is computed but
    deliberately not used in the returned value.
    """

    def __init__(self,
                 device,
                 fft_sizes=[2560, 1280, 640, 320, 160, 80],
                 hop_sizes=[640, 320, 160, 80, 40, 20],
                 win_lengths=[2560, 1280, 640, 320, 160, 80],
                 window="hann_window"):
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses.append(STFTLoss(device, fft_size, hop_size, win_length, window))

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi-resolution spectral convergence loss (averaged
            over resolutions).
        """
        sc_total = 0.0
        for stft_loss in self.stft_losses:
            sc_l, _ = stft_loss(x, y)
            sc_total += sc_l
        return sc_total / len(self.stft_losses)

View File

@@ -0,0 +1,128 @@
import os
import argparse

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset

nb_features = 36
nb_used_features = 20

parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

features_file = args.features
signal_file = args.output

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Restore the trained generator from the checkpoint.
checkpoint = torch.load(args.model, map_location='cpu')
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)

# The first nb_used_features columns are acoustic features; the
# remaining columns carry the LPC coefficients.
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
lpc = features[:, 4-1:-1, nb_used_features:]
features = features[:, :, :nb_used_features]
# Decode the pitch period from its encoded feature.
periods = np.round(np.clip(256./2**(features[:, :, nb_used_features-2]+1.5), 32, 255)).astype('int')

nb_frames = features.shape[1]
gamma = checkpoint['model_kwargs']['gamma']
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    """All-pole (LPC) synthesis of one frame of excitation samples.

    `buffer` holds the 16 previous output samples; each new sample is
    the input minus the weighted dot product of the history with the
    (flipped) filter coefficients.  Returns the synthesized frame.
    """
    out = np.zeros_like(frame)
    taps = np.flip(filt)
    for i in range(frame.shape[0]):
        s = frame[i] - np.dot(buffer * weighting_vector, taps)
        # Push the new sample into the history (newest ends up last).
        buffer[0] = s
        buffer = np.roll(buffer, -1)
        out[i] = s
    return out
def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
    """Undo perceptual weighting on a signal, 160 samples (one frame) at a time.

    Each frame is synthesized through its own all-pole filter
    (W(z/gamma) inverse), with the 16-sample filter state carried across
    frame boundaries.
    """
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] //160
    assert num_frames == filters.shape[0]
    for frame_idx in range(0, num_frames):
        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
        # Carry the last 16 output samples as the next frame's filter state.
        buffer[:] = out_sig_frame[-16:]
    return signal
def inverse_perceptual_weighting40 (pw_signal, filters):
    """Same as inverse_perceptual_weighting but at subframe (40-sample)
    granularity and without an extra weighting vector (the default all-ones
    weighting of lpc_synthesis_one_frame is used)."""
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    num_frames = pw_signal.shape[0] //40
    assert num_frames == filters.shape[0]
    for frame_idx in range(0, num_frames):
        in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:]
        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer)
        signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:]
        # Carry the last 16 output samples as the next subframe's filter state.
        buffer[:] = out_sig_frame[-16:]
    return signal
from scipy.signal import lfilter

if __name__ == '__main__':
    # Run synthesis over the whole feature file and write 16-bit PCM.
    model.to(device)
    features = torch.tensor(features).to(device)
    periods = torch.tensor(periods).to(device)

    # Bandwidth-expand the LPC coefficients (A(z/gamma)) and upsample
    # them to subframe rate.
    weighting = gamma**np.arange(1, 17)
    lpc = lpc*weighting
    lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()

    sig, _ = model(features, periods, nb_frames - 4)
    # FIX: move the output back to the host before converting to numpy —
    # Tensor.numpy() raises a TypeError on CUDA tensors, so the original
    # code crashed whenever a GPU was available.
    sig = sig.detach().cpu().numpy().flatten()
    # De-emphasis: undo the 1/(1 - 0.85 z^-1) pre-emphasis.
    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
    pcm.tofile(signal_file)

View File

@@ -0,0 +1,169 @@
import os
import argparse
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset
from stft_loss import *

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay
adam_betas = [0.8, 0.95]
adam_eps = 1e-8

features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size

# record the run configuration inside the checkpoint
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    # resume: the loaded checkpoint replaces the freshly-built metadata
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()

dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)

# learning rate scheduler (inverse-time decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay * x))

states = None
spect_loss = MultiResolutionSTFTLoss(device).to(device)

if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):
        running_specc = 0
        running_cont_loss = 0
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                periods = periods.to(device)
                if np.random.rand() > 0.1:
                    # usual case: truncate to the training sequence length
                    target = target[:, :sequence_length*160]
                    features = features[:, :sequence_length+4, :]
                    periods = periods[:, :sequence_length+4]
                else:
                    # occasionally train on full-length sequences with half the batch
                    target = target[::2, :]
                    features = features[::2, :]
                    periods = periods[::2, :]
                target = target.to(device)

                # teacher-forced pre-roll of 2 frames
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                sig = torch.cat([pre, sig], -1)

                # continuity loss on the first generated frame + spectral loss
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
                specc_loss = spect_loss(sig, target.detach())
                loss = .03*cont_loss + specc_loss

                loss.backward()
                optimizer.step()
                scheduler.step()

                running_specc += specc_loss.detach().cpu().item()
                running_cont_loss += cont_loss.detach().cpu().item()
                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   specc=f"{running_specc/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)

View File

@@ -0,0 +1,88 @@
import os
import sys
import argparse

import torch
from torch import nn

sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch

from models import model_dict

# Layers that are exported in float (i.e. excluded from quantization).
unquantized = [
    'bfcc_with_corr_upsampler.fc',
    'cont_net.0',
    'fwc6.cont_fc.0',
    'fwc6.fc.0',
    'fwc6.fc.1.gate',
    'fwc7.cont_fc.0',
    'fwc7.fc.0',
    'fwc7.fc.1.gate'
]

description=f"""
This is an unsafe dumping script for FWGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layer
and will fail to export any other weights.
Furthermore, the quanitze option relies on the following explicit list of layers to be excluded:
{unquantized}.
Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

if __name__ == "__main__":
    args = parser.parse_args()

    model = model_dict[args.model]()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    model.load_state_dict(saved_gen)

    def _remove_weight_norm(m):
        # Strip the weight_norm reparameterization; modules without it
        # raise ValueError, which we simply ignore.
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:
            return

    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model = args.quantize
    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)

    for name, module in model.named_modules():
        # Quantize everything except the explicit exclusion list; float
        # layers are exported with a fixed 1/128 scale.
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        export_name = name.replace('.', '_')
        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, export_name, quantize=quantize, scale=scale)
        if isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, export_name, quantize=quantize, scale=scale)
        if isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, export_name, quantize=quantize, scale=scale, recurrent_scale=scale)

    writer.close()

View File

@@ -0,0 +1,141 @@
import os
import time
import torch
import numpy as np
from scipy import signal as si
from scipy.io import wavfile
import argparse
from models import model_dict
# CLI: model variant, trained weights, and an input/output pair that may be
# either a single feature file + wav name or a folder + output folder.
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')
########################### Signal Processing Layers ###########################
def preemphasis(x, coef= -0.85):
    """First-order pre-emphasis FIR filter: y[n] = x[n] + coef * x[n-1]."""
    numerator = np.array([1.0, coef])
    denominator = np.array([1.0])
    return si.lfilter(numerator, denominator, x).astype('float32')
def deemphasis(x, coef= -0.85):
    """Inverse of preemphasis: IIR filter y[n] = x[n] - coef * y[n-1]."""
    numerator = np.array([1.0])
    denominator = np.array([1.0, coef])
    return si.lfilter(numerator, denominator, x).astype('float32')
# Bandwidth-expansion factor for the perceptual weighting filter W(z/gamma);
# weighting_vector holds gamma^16 ... gamma^1 (oldest to newest history tap).
gamma = 0.92
weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
    """Run one frame through the all-pole LPC synthesis filter `filt`.

    `buffer` holds the 16 previous output samples (oldest first after the
    internal roll); `weighting_vector` applies per-tap bandwidth expansion.
    """
    synth = np.zeros_like(frame)
    taps = np.flip(filt)
    for n, excitation in enumerate(frame[:]):
        # Subtract the weighted contribution of the 16-sample history.
        sample = excitation - np.dot(buffer * weighting_vector, taps)
        buffer[0] = sample
        buffer = np.roll(buffer, -1)  # newest sample moves to the end
        synth[n] = sample
    return synth
def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    """Undo the perceptual weighting: apply H_preemph / W(z/gamma) frame by frame."""
    # inverse perceptual weighting = H_preemph / W(z/gamma)
    emphasized = preemphasis(pw_signal)
    out = np.zeros_like(emphasized)
    history = np.zeros(16)
    n_frames = emphasized.shape[0] // 160
    assert n_frames == filters.shape[0]
    for k in range(n_frames):
        span = slice(k * 160, (k + 1) * 160)
        frame_out = lpc_synthesis_one_frame(emphasized[span][:], filters[k, :], history, weighting_vector)
        out[span] = frame_out[:]
        # Carry the last 16 synthesized samples into the next frame.
        history[:] = frame_out[-16:]
    return out
def process_item(generator, feature_filename, output_filename, verbose=False):
    """Run the vocoder on one LPCNet feature file and write a 16 kHz wav.

    Feature layout per 36-float frame (as read below): [0:18] BFCC,
    [18] pitch feature, [19] correlation, [-16:] LPC filter coefficients.
    """
    feat = np.memmap(feature_filename, dtype='float32', mode='r')
    num_feat_frames = len(feat) // 36
    feat = np.reshape(feat, (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])
    corr = np.copy(feat[:, 19:20]) + 0.5
    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)#.to(device)

    # Map the normalized pitch feature to an integer period in samples.
    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)\
        .astype('int32')).type(torch.long).view(1,-1)#.to(device)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.time()
    x1 = generator(period, bfcc_with_corr, torch.zeros(1,320)) #this means the vocoder runs in complete synthesis mode with zero history audio frames
    end_time = time.time()
    total_time = end_time - start_time

    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1)/16000

    # Undo the perceptual weighting and pre-emphasis to get plain audio.
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))

    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    # Convert float audio to int16 PCM with clipping.
    out = np.clip(np.round(2**15 * out), -2**15, 2**15 -1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)
########################### The inference loop over folder containing lpcnet feature files #################################
if __name__ == "__main__":
    args = parser.parse_args()
    generator = model_dict[args.model]()

    #Load the FWGAN500Hz Checkpoint
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(saved_gen)

    #this is just to remove the weight_norm from the model layers as it's no longer needed
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    generator.apply(_remove_weight_norm)

    #enable inference mode
    generator = generator.eval()

    print('Successfully loaded the generator model ... start generation:')

    # Folder input: process every feature file inside it; otherwise treat the
    # input as a single feature file and the output as a wav filename.
    if os.path.isdir(args.input):
        os.makedirs(args.output, exist_ok=True)
        for fn in os.listdir(args.input):
            print(f"processing input {fn}...")
            feature_filename = os.path.join(args.input, fn)
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)
    else:
        process_item(generator, args.input, args.output)

    print("Finished!")

View File

@@ -0,0 +1,7 @@
from .fwgan400 import FWGAN400ContLarge
from .fwgan500 import FWGAN500Cont
# Map CLI model names to their constructors.
model_dict = dict(
    fwgan400=FWGAN400ContLarge,
    fwgan500=FWGAN500Cont,
)

View File

@@ -0,0 +1,308 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np
which_norm = weight_norm
#################### Definition of basic model components ####################
#Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
class ConvLookahead(nn.Module):
    """Conv1d with exactly one dilation step of look-ahead.

    Pads (kernel_size - 2) * dilation samples on the left and one dilation
    step on the right, so the output length equals the input length while the
    receptive field extends one step into the future.
    """
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias= False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)  # fixed seed so the orthogonal init is reproducible

        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch,out_ch,kernel_size,dilation=dilation, groups=groups, bias= bias))

        self.init_weights()

    def init_weights(self):
        # Orthogonal initialization for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # x: (batch, channels, time); padded so the time dimension is preserved.
        x = F.pad(x,(self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out
#(modified) GLU Activation layer definition
class GLU(nn.Module):
    """Modified gated linear unit: tanh(x) * sigmoid(W x), gate computed from x itself."""
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for any weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out
#GRU layer definition
class ContForwardGRU(nn.Module):
    """Unidirectional GRU whose initial hidden state is predicted from a
    64-dim continuation latent, with a GLU applied to the outputs."""
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init
        self.hidden_size = hidden_size

        # Maps the 64-dim continuation latent to the initial GRU state.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,\
                          bias=False)
        self.nl = GLU(self.hidden_size)
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for Linear weights only; nn.GRU is not in the
        # isinstance list, so the GRU keeps its default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        # x: (batch, frames, input_size); x0: (batch, 64) continuation latent.
        self.gru.flatten_parameters()
        h0 = self.cont_fc(x0).unsqueeze(0)
        output, h0 = self.gru(x, h0)
        return self.nl(output)
# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):
    """Frame-wise 'convolution': frame_kernel_size consecutive frames are
    concatenated and mapped through a Linear layer (a conv whose stride equals
    the frame length). The left context for the first frames is not
    zero-padded but predicted from the 64-dim continuation latent (cont_fc).
    """
    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init
        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0
            # Predicts the missing left context from the continuation latent.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )
        else:
            # NOTE(review): this non-causal branch defines no cont_fc, so
            # forward() would raise AttributeError for causal=False with
            # kernel size > 2 -- it appears unused; confirm before relying on it.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act=='glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act=='tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
            isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        # x: (batch, frames, frame_len); x0: (batch, 64) continuation latent.
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0),1,-1)
        # Predicted left context replaces zero padding.
        pad = self.cont_fc(x0).view(x0.size(0),1,-1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
        # Unfold into overlapping windows of frame_kernel_size frames.
        x_flat_padded_unfolded = F.unfold(x_flat_padded,\
            kernel_size= (1,self.fc_input_dim), stride=self.frame_len).permute(0,2,1).contiguous()
        out = self.fc(x_flat_padded_unfolded)
        return out
# A fully-connected based upsampling layer definition
class UpsampleFC(nn.Module):
    """Upsample a (batch, channels, frames) sequence by an integer factor
    using a single bias-free Linear layer followed by tanh."""
    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()
        self.init_weights()

    def init_weights(self):
        """Orthogonal initialization for all weight-bearing submodules."""
        weighted = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for m in self.modules():
            if isinstance(m, weighted):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # (B, C_in, T) -> (B, T, C_in) -> (B, T, C_out*f) -> (B, T*f, C_out) -> (B, C_out, T*f)
        n_batch = x.size(0)
        hidden = self.nl(self.fc(x.permute(0, 2, 1)))
        upsampled = hidden.reshape((n_batch, -1, self.out_ch))
        return upsampled.permute(0, 2, 1)
########################### The complete model definition #################################
class FWGAN400ContLarge(nn.Module):
    """FWGAN vocoder (400 Hz feature-rate, large variant) with continuation.

    Conditions on per-frame pitch periods and 19 BFCC+correlation features.
    x0 holds 320 history samples; they are normalized, prefixed with their
    log energy (321 dims) and compressed by cont_net into a 64-dim latent
    that seeds the GRU state and the left context of every fwc layer.
    """
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)  # reproducible orthogonal init

        # PreCondNet: upsample the 19 features 4x in time to 80 channels.
        self.bfcc_with_corr_upsampler = UpsampleFC(19,80,4)
        self.feat_in_conv1 = ConvLookahead(160,256,kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        # Compresses [log-norm, 320 normalized history samples] (321 dims)
        # into the 64-dim continuation latent.
        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256,256)

        # Frame-wise convolution stack, narrowing to 40 samples per subframe.
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        # Orthogonal init for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
            isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        # Print the trainable parameter count at construction time.
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        # Per 40-sample subframe sin/cos phase embeddings (80 dims per
        # subframe); the phase accumulator phase0 is carried across frames.
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0),-1,40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0),-1,40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim = -1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)
        return phase_signals

    def gain_multiply(self, x, c0):
        # Per-frame gain derived from the first cepstral coefficient c0,
        # broadcast over the 160 output samples of each frame.
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0),1,-1).squeeze(1)
        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # Normalize the history samples, prepend their log energy -> 321 dims.
        norm_x0 = torch.norm(x0,2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt((1e-8) + norm_x0**2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0,2,1).contiguous())
        feat_in = torch.cat((p_embed , envelope), dim=1)
        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0,2,1).contiguous())

        # Continuation latent shared by the GRU and all fwc layers.
        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)
        fwc1_out = self.fwc1(rnn_out, cont_latent)
        fwc2_out = self.fwc2(fwc1_out, cont_latent)
        fwc3_out = self.fwc3(fwc2_out, cont_latent)
        fwc4_out = self.fwc4(fwc3_out, cont_latent)
        fwc5_out = self.fwc5(fwc4_out, cont_latent)
        fwc6_out = self.fwc6(fwc5_out, cont_latent)
        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        # Flatten subframes to a single waveform and apply the per-frame gain.
        waveform = fwc7_out.reshape(fwc7_out.size(0),1,-1).squeeze(1)
        waveform = self.gain_multiply(waveform,bfcc_with_corr[:,:,:1])
        return waveform

View File

@@ -0,0 +1,260 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np
which_norm = weight_norm
#################### Definition of basic model components ####################
#Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
class ConvLookahead(nn.Module):
    """Conv1d with exactly one dilation step of look-ahead.

    Pads (kernel_size - 2) * dilation samples on the left and one dilation
    step on the right, preserving the output length.
    """
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias= False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)  # fixed seed so the orthogonal init is reproducible

        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch,out_ch,kernel_size,dilation=dilation, groups=groups, bias= bias))

        self.init_weights()

    def init_weights(self):
        # Orthogonal initialization for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # x: (batch, channels, time)
        x = F.pad(x,(self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out
#(modified) GLU Activation layer definition
class GLU(nn.Module):
    """Modified gated linear unit: tanh(x) * sigmoid(W x), gate computed from x."""
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for any weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out
#GRU layer definition
class ContForwardGRU(nn.Module):
    """Unidirectional GRU whose initial state is predicted from the 320
    history audio samples, with a GLU applied to the outputs."""
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init
        self.hidden_size = hidden_size

        #This is to initialize the layer with history audio samples for continuation.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,\
                          bias=False)
        self.nl = GLU(self.hidden_size)
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for Linear weights only; nn.GRU is not in the
        # isinstance list, so the GRU keeps its default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        # x: (batch, frames, input_size); x0: (batch, 320) history samples.
        self.gru.flatten_parameters()
        h0 = self.cont_fc(x0).unsqueeze(0)
        output, h0 = self.gru(x, h0)
        return self.nl(output)
# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):
    """Frame-wise 'convolution': frame_kernel_size consecutive frames are
    concatenated and mapped through a Linear layer (a conv whose stride equals
    the frame length). The left context for the first frames is predicted
    from the 320 history audio samples instead of zero padding.
    """
    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)  # reproducible orthogonal init
        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0
            #This is to initialize the layer with history audio samples for continuation.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                         )
        else:
            #This means non-causal frame-wise convolution. We don't use it at the moment
            # NOTE(review): no cont_fc is defined here, so forward() would raise
            # AttributeError for causal=False with kernel size > 2.
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act=='glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                    )
        if act=='tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                    )

        self.init_weights()

    def init_weights(self):
        # Orthogonal init for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
            isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        # x: (batch, frames, frame_len); x0: (batch, 320) history samples.
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0),1,-1)
        # Predicted left context replaces zero padding.
        pad = self.cont_fc(x0).view(x0.size(0),1,-1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
        # Unfold into overlapping windows of frame_kernel_size frames.
        x_flat_padded_unfolded = F.unfold(x_flat_padded,\
            kernel_size= (1,self.fc_input_dim), stride=self.frame_len).permute(0,2,1).contiguous()
        out = self.fc(x_flat_padded_unfolded)
        return out
########################### The complete model definition #################################
class FWGAN500Cont(nn.Module):
    """FWGAN vocoder (500 Hz frame-rate variant) with continuation support.

    Conditions on per-frame pitch periods and 19 BFCC+correlation features;
    x0 holds 320 raw history audio samples that directly initialize the GRU
    state and the left context of every frame-wise conv layer.
    """
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)  # reproducible orthogonal init

        #PrecondNet:
        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19,64,kernel_size=5,stride=5,padding=0,\
                                                      bias=False),
                                                      nn.Tanh())
        self.feat_in_conv = ConvLookahead(128,256,kernel_size=5)
        self.feat_in_nl = GLU(256)

        #GRU:
        self.rnn = ContForwardGRU(256,256)

        #Frame-wise convolution stack:
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 32)
        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        # Orthogonal init for all weight-bearing submodules.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
            isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        # Print the trainable parameter count at construction time.
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        # Per 32-sample subframe sin/cos phase embeddings (64 dims per
        # subframe); the phase accumulator phase0 is carried across frames.
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0),-1,32)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0),-1,32)

            chunk = torch.cat((chunk_sin, chunk_cos), dim = -1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)
        return phase_signals

    def gain_multiply(self, x, c0):
        # Per-frame gain derived from the first cepstral coefficient c0,
        # broadcast over the 160 output samples of each frame.
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0),1,-1).squeeze(1)
        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        #This should create a latent representation of shape [Batch_dim, 500 frames, 256 elements per frame]
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0,2,1).contiguous())
        feat_in = torch.cat((p_embed , envelope), dim=1)
        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0,2,1).contiguous())

        #Generation with continuation using history samples x0 starts from here:
        rnn_out = self.rnn(wav_latent, x0)
        fwc1_out = self.fwc1(rnn_out, x0)
        fwc2_out = self.fwc2(fwc1_out, x0)
        fwc3_out = self.fwc3(fwc2_out, x0)
        fwc4_out = self.fwc4(fwc3_out, x0)
        fwc5_out = self.fwc5(fwc4_out, x0)
        fwc6_out = self.fwc6(fwc5_out, x0)
        fwc7_out = self.fwc7(fwc6_out, x0)

        # Flatten subframes to a single waveform and apply the per-frame gain.
        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0),1,-1).squeeze(1)
        waveform = self.gain_multiply(waveform_unscaled,bfcc_with_corr[:,:,:1])
        return waveform

View File

@@ -0,0 +1,27 @@
#Packet loss simulator
This code is an attempt at simulating better packet loss scenarios. The most common way of simulating
packet loss is to use a random sequence where each packet loss event is uncorrelated with previous events.
That is a simplistic model since we know that losses often occur in bursts. This model uses real data
to build a generative model for packet loss.
We use the training data provided for the Audio Deep Packet Loss Concealment Challenge, which is available at:
http://plcchallenge2022pub.blob.core.windows.net/plcchallengearchive/test_train.tar.gz
To create the training data, run:
`./process_data.sh /<path>/test_train/train/lossy_signals/`
That will create an ascii loss\_sorted.txt file with all loss data sorted in increasing packet loss
percentage. Then just run:
`python ./train_lossgen.py`
to train a model
To generate a sequence, run:

`python3 ./test_lossgen.py <checkpoint> <percentage> output.txt --length 10000`

where `<checkpoint>` is the .pth model file and `<percentage>` is the fraction of packets to lose (e.g. 0.2 for 20% loss).

View File

@@ -0,0 +1,101 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
# CLI: checkpoint to export and destination folder for the generated C code.
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
args = parser.parse_args()
import torch
import numpy as np
import lossgen
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
def c_export(args, model):
    """Dump the LossGen dense and GRU weights as C source/header files."""
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False, add_typedef=True)

    writer.header.write(
f"""
#include "opus_types.h"
"""
    )

    # Dense layers are exported in float (quantize=False).
    dense_layers = [
        ('dense_in', "lossgen_dense_in"),
        ('dense_out', "lossgen_dense_out")
    ]

    for name, export_name in dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)

    # GRU layers are quantized; the widest one sizes the C scratch buffer.
    gru_layers = [
        ("gru1", "lossgen_gru1"),
        ("gru2", "lossgen_gru2"),
    ]

    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
                         for name, export_name in gru_layers])

    writer.header.write(
f"""
#define LOSSGEN_MAX_RNN_UNITS {max_rnn_units}
"""
    )

    writer.close()
if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)

    # Rebuild the model from the checkpoint's stored constructor args, then
    # load the trained weights (strict=False tolerates extra/missing keys).
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    #model = LossGen()
    #checkpoint = torch.load(args.checkpoint, map_location='cpu')
    #model.load_state_dict(checkpoint['state_dict'])

    c_export(args, model)

View File

@@ -0,0 +1,29 @@
import torch
from torch import nn
import torch.nn.functional as F
class LossGen(nn.Module):
    """Generative model of bursty packet-loss sequences.

    A small dense layer followed by two stacked GRUs predicts the logit of
    the next packet being lost, conditioned on the previous loss flag and a
    running loss-percentage signal.
    """
    def __init__(self, gru1_size=16, gru2_size=16):
        super(LossGen, self).__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """Predict next-step loss logits.

        Args:
            loss: (batch, seq, 1) previous loss flags (0/1).
            perc: (batch, seq, 1) running loss percentage.
            states: optional [gru1_state, gru2_state]; zeros when None.

        Returns:
            (logits, [gru1_state, gru2_state]); apply sigmoid to the logits
            to obtain the loss probability.
        """
        device = loss.device
        batch_size = loss.size(0)
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
        else:
            # Unpack the caller-provided recurrent states.
            gru1_state, gru2_state = states
        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]

View File

@@ -0,0 +1,17 @@
#!/bin/sh
#directory containing the loss files
datadir=$1

# Compute the average loss rate of every *_is_lost.txt trace and record
# "percentage filename" pairs.
for i in $datadir/*_is_lost.txt
do
	perc=`cat $i | awk '{a+=$1}END{print a/NR}'`
	echo $perc $i
done > percentage_list.txt

# Order the trace filenames by increasing loss percentage.
sort -n percentage_list.txt | awk '{print $2}' > percentage_sorted.txt

# Concatenate all traces in that order into a single training file.
for i in `cat percentage_sorted.txt`
do
	cat $i
done > loss_sorted.txt

View File

@@ -0,0 +1,42 @@
import lossgen
import os
import argparse
import torch
import numpy as np
# CLI for sampling a synthetic loss sequence from a trained LossGen model.
parser = argparse.ArgumentParser()
# Fixed help text: this script loads a lossgen checkpoint, not a CELPNet model.
parser.add_argument('model', type=str, help='lossgen model checkpoint')
parser.add_argument('percentage', type=float, help='percentage loss')
parser.add_argument('output', type=str, help='path to output file (ascii)')
parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)

args = parser.parse_args()
# Rebuild the model from the checkpoint's stored constructor arguments.
checkpoint = torch.load(args.model, map_location='cpu')
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)

# Generation state: previous loss flag, target loss percentage, constants.
states = None
last = torch.zeros((1, 1, 1))
perc = torch.tensor((args.percentage,))[None, None, :]
one = torch.ones((1, 1, 1))
zero = torch.zeros((1, 1, 1))

if __name__ == '__main__':
    # Accumulate samples in a list and concatenate once at the end; the
    # previous per-step torch.cat made generation O(n^2) in sequence length.
    samples = []
    for _ in range(args.length):
        prob, states = model(last, perc, states=states)
        prob = torch.sigmoid(prob)
        # Detach so the autograd graph does not grow across steps.
        states[0] = states[0].detach()
        states[1] = states[1].detach()
        last = one if np.random.rand() < prob else zero
        samples.append(last)
    seq = torch.cat(samples) if samples else torch.zeros((0, 1, 1))
    np.savetxt(args.output, seq[:, :, 0].numpy().astype('int'), fmt='%d')

View File

@@ -0,0 +1,99 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
from scipy.signal import lfilter
import os
import lossgen
class LossDataset(torch.utils.data.Dataset):
    """Fixed-length windows of a packet-loss trace plus a smoothed loss rate.

    The trace is a text file of 0/1 loss flags; `perc` is a one-pole
    low-pass (time constant ~1000 packets) of the flags, perturbed with
    random noise at fetch time for data augmentation.
    """
    def __init__(self,
                 loss_file,
                 sequence_length=997):

        self.sequence_length = sequence_length

        raw = np.loadtxt(loss_file, dtype='float32')
        # Keep only whole sequences.
        self.nb_sequences = raw.shape[0] // self.sequence_length
        raw = raw[:self.nb_sequences * self.sequence_length]

        # Running loss percentage via a one-pole IIR smoother.
        smoothed = lfilter(np.array([.001], dtype='float32'),
                           np.array([1., -.999], dtype='float32'), raw)

        shape = (self.nb_sequences, self.sequence_length, 1)
        self.loss = np.reshape(raw, shape)
        self.perc = np.reshape(smoothed, shape)

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        # Per-sequence and per-sample noise, scaled by perc*(1-perc) so the
        # perturbation vanishes at the 0/1 extremes.
        seq_noise = np.random.normal(scale=.1, size=(1, 1)).astype('float32')
        samp_noise = np.random.normal(scale=.1, size=(self.sequence_length, 1)).astype('float32')
        perc = self.perc[index, :, :]
        perc = perc + (seq_noise + samp_noise) * perc * (1 - perc)
        return [self.loss[index, :, :], perc]
# Optimizer hyper-parameters.
adam_betas = [0.8, 0.98]
adam_eps = 1e-8
batch_size=256
lr_decay = 0.001
lr = 0.003
epsilon = 1e-5  # numerical floor inside the cross-entropy logs
epochs = 2000

checkpoint_dir='checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = dict()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Store the constructor arguments so checkpoints are self-describing.
checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'gru1_size': 16, 'gru2_size': 32}
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])

dataset = LossDataset('loss_sorted.txt')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)

# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
if __name__ == '__main__':
    model.to(device)

    # GRU states are carried across batches (and epochs) and only detached,
    # so gradients never flow between batches but the state itself persists.
    states = None
    for epoch in range(1, epochs + 1):
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (loss, perc) in enumerate(tepoch):
                optimizer.zero_grad()
                loss = loss.to(device)
                perc = perc.to(device)

                out, states = model(loss, perc, states=states)
                states = [state.detach() for state in states]
                # Predict step t+1 from steps up to t.
                out = torch.sigmoid(out[:,:-1,:])
                target = loss[:,1:,:]

                # Binary cross-entropy on next-step loss prediction.
                # NOTE(review): `loss` is reused here for the scalar objective,
                # shadowing the input tensor of the same name.
                loss = torch.mean(-target*torch.log(out+epsilon) - (1-target)*torch.log(1-out+epsilon))

                loss.backward()
                optimizer.step()

                # The LR scheduler is stepped per batch, not per epoch.
                scheduler.step()

                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)

View File

@@ -0,0 +1,27 @@
# LPCNet
Incomplete pytorch implementation of LPCNet
## Data preparation
For data preparation use dump_data in github.com/xiph/LPCNet. To turn this into
a training dataset, copy data and feature file to a folder and run
python add_dataset_config.py my_dataset_folder
## Training
To train a model, create and adjust a setup file, e.g. with
python make_default_setup.py my_setup.yml --path2dataset my_dataset_folder
Then simply run
python train_lpcnet.py my_setup.yml my_output
## Inference
Create a feature file with dump_data from github.com/xiph/LPCNet. Then run, e.g.,

`python test_lpcnet.py features.f32 my_output/checkpoints/checkpoint_ep_10.pth out.wav`

Inference runs on the CPU and usually takes between 3 and 20 seconds per generated second of audio,
depending on the CPU.

View File

@@ -0,0 +1,77 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import os
import yaml
from utils.templates import dataset_template_v1, dataset_template_v2
parser = argparse.ArgumentParser("add_dataset_config.py")
parser.add_argument('path', type=str, help='path to folder containing feature and data file')
parser.add_argument('--version', type=int, help="dataset version, 1 for classic LPCNet with 55 feature slots, 2 for new format with 36 feature slots.", default=2)
parser.add_argument('--description', type=str, help='brief dataset description', default="I will add a description later")
args = parser.parse_args()
if args.version == 1:
template = dataset_template_v1
data_extension = '.u8'
elif args.version == 2:
template = dataset_template_v2
data_extension = '.s16'
else:
raise ValueError(f"unknown dataset version {args.version}")
# get folder content
content = os.listdir(args.path)
features = [c for c in content if c.endswith('.f32')]
if len(features) != 1:
print("could not determine feature file")
else:
template['feature_file'] = features[0]
data = [c for c in content if c.endswith(data_extension)]
if len(data) != 1:
print("could not determine data file")
else:
template['signal_file'] = data[0]
template['description'] = args.description
with open(os.path.join(args.path, 'info.yml'), 'w') as f:
yaml.dump(template, f)

View File

@@ -0,0 +1 @@
from .lpcnet_dataset import LPCNetDataset

View File

@@ -0,0 +1,227 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
# u-law companding constants: map between 8-bit u-law codes and 16-bit linear PCM
scale = 255.0/32768.0
scale_1 = 32768.0/255.0

def ulaw2lin(u):
    """Expand u-law codes in [0, 255] (128 = zero) to linear sample values."""
    centered = u - 128
    sign = np.sign(centered)
    magnitude = np.abs(centered)
    # inverse of the log compression; np.exp(x*log(256)) == 256**x
    return sign * scale_1 * (np.exp(magnitude / 128. * np.log(256)) - 1)
def lin2ulaw(x):
    """Compress linear sample values to u-law codes clipped to [0, 255]."""
    sign = np.sign(x)
    magnitude = np.abs(x)
    # log compression; 255/32768 normalizes int16 full scale to 255
    code = sign * (128 * np.log(1 + (255.0 / 32768.0) * magnitude) / np.log(256))
    return np.clip(128 + np.round(code), 0, 255)
def run_lpc(signal, lpcs, frame_length=160):
    """Run per-frame LPC prediction over a signal.

    Args:
        signal: 1d array of length num_frames * frame_length + lpc_order
                (the first lpc_order samples serve as filter history)
        lpcs: (num_frames, lpc_order) array of LPC coefficients
        frame_length: samples per frame

    Returns:
        (prediction, error): prediction from the LPC filters and the residual
        signal[lpc_order:] - prediction, each of length num_frames * frame_length.
    """
    num_frames, lpc_order = lpcs.shape

    frame_predictions = []
    for frame in range(num_frames):
        start = frame * frame_length
        stop = (frame + 1) * frame_length + lpc_order - 1
        # 'valid' convolution yields exactly frame_length prediction samples
        frame_predictions.append(- np.convolve(signal[start : stop], lpcs[frame], mode='valid'))

    prediction = np.concatenate(frame_predictions)
    error = signal[lpc_order :] - prediction

    return prediction, error
class LPCNetDataset(Dataset):
    """Dataset for LPCNet training.

    Wraps a memory-mapped feature file (one row per frame) and a memory-mapped
    signal file (one row per sample) and returns aligned samples of
    `frames_per_sample` frames, with `feature_history`/`feature_lookahead`
    extra context frames. Two on-disk layouts are supported; the version is
    read from the dataset's info.yml.
    """
    def __init__(self,
                 path_to_dataset,
                 features=None,
                 input_signals=None,
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):
        """
        Args:
            path_to_dataset: folder containing info.yml, the feature file and the signal file
            features: frame-rate feature names to extract (default: ['cepstrum', 'periods', 'pitch_corr'])
            input_signals: sample-rate input signal names (default: ['last_signal', 'prediction', 'last_error'])
            target: name of the target signal
            frames_per_sample: number of frames per returned item
            feature_history: past context frames
            feature_lookahead: future context frames
            lpc_gamma: bandwidth-expansion factor applied to the LPC coefficients
        """
        super(LPCNetDataset, self).__init__()

        # avoid mutable default arguments; these match the documented defaults
        if features is None:
            features = ['cepstrum', 'periods', 'pitch_corr']
        if input_signals is None:
            input_signals = ['last_signal', 'prediction', 'last_error']

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the item getter
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 1 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file (memory-mapped; reshaped to one row per frame)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.input_signals = input_signals
        self.target = target

        # load signal file (memory-mapped; reshaped to one row per sample)
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))

        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        """Item getter for version-2 datasets (prediction/error derived from LPCs)."""
        sample = dict()

        # extract features (with history and lookahead context)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer embedding indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:

            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions of the LPC coefficients
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start

            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC bandwidth expansion: weight coefficient i by lpc_gamma ** (i + 1)
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the extra leading frame again; last_error is
            # shifted back by one sample)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features; periods are handled separately as integer indices
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def getitem_v1(self, index):
        """Item getter for version-1 datasets (signals stored pre-companded on disk)."""
        sample = dict()

        # extract features (with history and lookahead context)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer embedding indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # (loop variable renamed: the original `index` shadowed the method argument)
        for signal_name, column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, column]

        # concatenate features; periods are handled separately as integer indices
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length

View File

@@ -0,0 +1,141 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the average loss over the dataloader.

    Args:
        model: LPCNet-style model exposing gru_a_units, gru_b_units and sparsify()
        criterion: loss taking (output of shape (B, C, T), target of shape (B, T))
        optimizer: torch optimizer
        dataloader: yields dicts with 'features', 'periods', 'signals', 'target';
                    must have a fixed batch_size
        device: device to train on
        scheduler: per-step learning-rate scheduler
        log_interval: number of batches between progress-bar updates

    Returns:
        average loss over the epoch
    """
    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    # gru state buffers, allocated once on the target device and zeroed per batch
    # (fix: the original created them with device=device and then redundantly
    # called .to(device) again)
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # zero out initial gru states
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss (criterion expects the class dimension second)
            loss = criterion(output.permute(0, 2, 1), target)

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # call sparsifier so the pruning masks track the updated weights
            model.sparsify()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model without gradient updates and return the average loss.

    Args:
        model: LPCNet-style model exposing gru_a_units and gru_b_units
        criterion: loss taking (output of shape (B, C, T), target of shape (B, T))
        dataloader: yields dicts with 'features', 'periods', 'signals', 'target';
                    must have a fixed batch_size
        device: device to evaluate on
        log_interval: number of batches between progress-bar updates

    Returns:
        average loss over the dataloader
    """
    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    # gru state buffers, allocated once on the target device and zeroed per batch
    # (fix: the original created them with device=device and then redundantly
    # called .to(device) again)
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for i, batch in enumerate(tepoch):

                # zero out initial gru states
                gru_a_state.zero_()
                gru_b_state.zero_()

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

                # calculate loss (criterion expects the class dimension second)
                loss = criterion(output.permute(0, 2, 1), target)

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss

View File

@@ -0,0 +1,56 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import yaml
from utils.templates import setup_dict
parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lpcnet', 'multi_rate'], help='LPCNet model name', default='lpcnet')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
args = parser.parse_args()
setup = setup_dict[args.model]
# update dataset if given
if type(args.path2dataset) != type(None):
setup['dataset'] = args.path2dataset
name = args.name
if not name.endswith('.yml'):
name += '.yml'
if __name__ == '__main__':
with open(name, 'w') as f:
f.write(yaml.dump(setup))

View File

@@ -0,0 +1,78 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import os
import sys
parser = argparse.ArgumentParser()
parser.add_argument("config_name", type=str, help="name of config file (.yml will be appended)")
parser.add_argument("test_name", type=str, help="name for test result display")
parser.add_argument("checkpoint", type=str, help="checkpoint to test")
parser.add_argument("--lpcnet-demo", type=str, help="path to lpcnet_demo binary, default: /local/code/LPCNet/lpcnet_demo", default="/local/code/LPCNet/lpcnet_demo")
parser.add_argument("--lpcnext-path", type=str, help="path to lpcnext folder, defalut: dirname(__file__)", default=os.path.dirname(__file__))
parser.add_argument("--python-exe", type=str, help='python executable path, default: sys.executable', default=sys.executable)
parser.add_argument("--pad", type=str, help="left pad of output in seconds, default: 0.015", default="0.015")
parser.add_argument("--trim", type=str, help="left trim of output in seconds, default: 0", default="0")
template='''
test: "{NAME}"
processing:
- "sox {{INPUT}} {{INPUT}}.raw"
- "{LPCNET_DEMO} -features {{INPUT}}.raw {{INPUT}}.features.f32"
- "{PYTHON} {WORKING}/test_lpcnet.py {{INPUT}}.features.f32 {CHECKPOINT} {{OUTPUT}}.ua.wav"
- "sox {{OUTPUT}}.ua.wav {{OUTPUT}}.uap.wav pad {PAD}"
- "sox {{OUTPUT}}.uap.wav {{OUTPUT}} trim {TRIM}"
- "rm {{INPUT}}.raw {{OUTPUT}}.uap.wav {{OUTPUT}}.ua.wav {{INPUT}}.features.f32"
'''
if __name__ == "__main__":
args = parser.parse_args()
file_content = template.format(
NAME=args.test_name,
LPCNET_DEMO=os.path.abspath(args.lpcnet_demo),
PYTHON=os.path.abspath(args.python_exe),
PAD=args.pad,
TRIM=args.trim,
WORKING=os.path.abspath(args.lpcnext_path),
CHECKPOINT=os.path.abspath(args.checkpoint)
)
print(file_content)
filename = args.config_name
if not filename.endswith(".yml"):
filename += ".yml"
with open(filename, "w") as f:
f.write(file_content)

View File

@@ -0,0 +1,8 @@
from .lpcnet import LPCNet
from .multi_rate_lpcnet import MultiRateLPCNet

# registry mapping the `model` name used in setup/config files to the
# corresponding model class
model_dict = {
    'lpcnet' : LPCNet,
    'multi_rate' : MultiRateLPCNet
}

View File

@@ -0,0 +1,303 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import numpy as np
from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.layers import DualFC
from utils.misc import get_pdf_from_tree
class LPCNet(nn.Module):
    """LPCNet vocoder.

    A frame rate network turns acoustic features and pitch periods into a
    conditioning vector; a sample rate network (two GRUs + DualFC output)
    consumes it to predict one u-law excitation distribution per audio sample.
    Supports optional GRU sparsification and hierarchical sampling.
    """
    def __init__(self, config):
        """Builds the model from a configuration dictionary (see utils.templates)."""
        super(LPCNet, self).__init__()

        # input layout and feature context
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']


        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']
        # hierarchical sampling: output parameterizes a binary probability tree
        self.hsampling = config.get('hsampling', False)

        self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        self.dual_fc = DualFC(self.gru_b_units, self.output_levels)

        # sparsification (optional, configured per GRU); the per-step FLOP counts
        # below account for the sparsity masks
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            # drop_input=True: GRU A input matmul is implemented via embedding
            # lookup at inference time -- TODO confirm
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # inference parameters
        self.lpc_gamma = config.get('lpc_gamma', 1)

    def sparsify(self):
        """Advances all configured sparsifiers by one step (call once per training step)."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def get_gflops(self, fs, verbose=False):
        """Estimates inference complexity in GFLOPS at sampling rate fs (Hz)."""
        gflops = 0

        # frame rate network (runs once per frame)
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a (runs once per sample)
        gru_a_rate = fs
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # gru b (runs once per sample)
        gru_b_rate = fs
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # dual fcs
        fc = self.dual_fc
        rate = fs
        input_size = fc.dense1.in_features
        output_size = fc.dense1.out_features
        dual_fc_complexity = 1e-9 *  (4 * input_size * output_size + 22 * output_size) * rate
        # hierarchical sampling only evaluates a fraction of the output tree
        if self.hsampling:
            dual_fc_complexity /= 8
        if verbose:
            print(f"dual_fc: {dual_fc_complexity} GFLOPS")
        gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def frame_rate_network(self, features, periods):
        """Maps (B, T, F) features and (B, T, P) period indices to a (B, T', C) conditioning sequence."""

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        # (the 'valid' convolutions consume the history/lookahead context frames)
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def sample_rate_network(self, signals, c, gru_states):
        """Training-time forward pass; returns per-sample log-probabilities and new GRU states."""
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        # repeat each conditioning frame for every sample in the frame
        c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)

        y = torch.concat((embedded_signals, c_upsampled), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c_upsampled), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
        else:
            log_probs = torch.log_softmax(y, dim=-1)

        return log_probs, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        """Single-step inference variant of sample_rate_network; `c` is already
        per-sample (no upsampling) and probabilities (not log-probs) are returned."""
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        if self.hsampling:
            y = torch.sigmoid(y)
            probs = get_pdf_from_tree(y)
        else:
            probs = torch.softmax(y, dim=-1)

        return probs, (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
        """Full training forward pass: conditioning + sample rate network log-probs."""

        c = self.frame_rate_network(features, periods)
        log_probs, _ = self.sample_rate_network(signals, c, gru_states)

        return log_probs

    def generate(self, features, periods, lpcs):
        """Autoregressive inference: generates int16 audio from features, periods and LPCs."""

        with torch.no_grad():
            device = self.parameters().__next__().device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.input_layout['signals'])
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers; pcm keeps lpc_order extra samples of history at the front
            pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            gru_states = [gru_a_state, gru_b_state]

            # 128 is the u-law code for zero (silence)
            input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # lpc weighting (bandwidth expansion by lpc_gamma)
            weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
            lpcs = lpcs * weights

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                # flipped so a dot product with the pcm history implements the LPC filter
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(self.frame_size):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # prepare input: short-term prediction from the last lpc_order samples
                    pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)

                    if 'prediction' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)

                    # run single step of sample rate network
                    probs, gru_states = self.decoder(
                        input_signals,
                        current_c,
                        gru_states
                    )

                    # sample from output
                    exc_ulaw = sample_excitation(probs, pitch_corr)

                    # signal generation: excitation plus short-term prediction
                    exc = ulaw2lin(exc_ulaw)
                    sig = exc + pred
                    pcm[pcm_position] = sig
                    # NOTE(review): looks like a de-emphasis filter with pole 0.85 -- confirm
                    mem = 0.85 * mem + float(sig)
                    output[output_position] = clip_to_int16(round(mem))

                    # buffer update: feed generated signal/excitation back as next inputs
                    if 'last_signal' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)

                    if 'last_error' in self.input_layout['signals']:
                        input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)

            return output

View File

@@ -0,0 +1,437 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC
from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
from utils.misc import interleave_tensors
# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
    """LPCNet variant whose two GRUs run at reduced sample rates.

    GRU A consumes groups of `substeps_a * substeps_b` samples at once and
    GRU B groups of `substeps_b` samples; the subconditioners re-inject the
    most recent signal values between the stages so the final dual-FC output
    is produced at full sample rate again.
    """

    def __init__(self, config):
        """Build all sub-modules from the `config` dict (parsed setup yaml)."""
        super(MultiRateLPCNet, self).__init__()
        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']
        self.signals = config['signals']
        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']
        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        # square layer: input and output dims are both feature_conditioning_dim
        self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))
        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']
        # subconditioning B
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        # NOTE(review): this None-guard normalizes the local `kwargs`, but the
        # call below expands sub_config['kwargs'] directly, so the guard is
        # dead code — a null 'kwargs' entry in the yaml would still raise.
        # Confirm whether **kwargs was intended.
        if type(kwargs) == type(None):
            kwargs = dict()
        state_size = self.gru_b_units
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **sub_config['kwargs'])
        # subconditioning A (same dead None-guard as for B above)
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if type(kwargs) == type(None):
            kwargs = dict()
        state_size = self.gru_a_units
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **sub_config['kwargs'])
        # wrap up subconditioning, group_size_gru_a holds the number
        # of timesteps that are grouped as sample input for GRU A
        # input and group_size_subcondition_a holds the number of samples that are
        # grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b
        # gru sizes
        self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]
        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
        # sparsification
        self.sparsifier = []
        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                gru_config['start'],
                gru_config['stop'],
                gru_config['interval'],
                gru_config['exponent'])
            )
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                gru_config['start'],
                gru_config['stop'],
                gru_config['interval'],
                gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
        # dual FCs: a plain Python list; modules are registered explicitly
        # via add_module so parameters are still tracked
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        """Estimate inference complexity in GFLOPS at sample rate `fs`.

        If `hierarchical_sampling` is set, the dual-FC cost is divided by 8
        (only a fraction of the output tree is evaluated per sample).
        """
        gflops = 0
        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity
        # gru a runs once per sample group
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity
        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity
        # gru b runs once per substep-b group
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity
        # subconditioning b runs at full sample rate
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity
        # dual fcs: each serves 1/len(dual_fc) of the samples
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity
        if verbose:
            print(f'total: {gflops} GFLOPS')
        return gflops

    def sparsify(self):
        """Advance every registered GRU sparsifier by one step (call per training step)."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):
        """Compute the frame-rate conditioning vectors from features and pitch periods."""
        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)
        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])
        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))
        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """ extracts, delays and groups signals

        Selects the columns in `signal_idx` from `signals`
        (batch, time, channels), shifts them so the newest sample of each
        group aligns with the group boundary, and reshapes time into groups
        of `group_size` consecutive samples.
        """
        batch_size, sequence_length, num_signals = signals.shape
        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)
        # roll back pcm to account for grouping
        signals = torch.roll(signals, group_size - 1, -2)
        # reshape
        signals = torch.reshape(signals,
            (batch_size, sequence_length // group_size, group_size * len(signal_idx)))
        return signals

    def sample_rate_network(self, signals, c, gru_states):
        """Run GRU A -> subcond A -> GRU B -> subcond B -> dual FCs (training path).

        Returns the interleaved dual-FC outputs and the final GRU states.
        """
        signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)
        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)
        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)
        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)
        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        """Single-step decoder without subconditioning.

        NOTE(review): `self.dual_fc` is a plain Python list (see __init__),
        so `self.dual_fc(y)` would raise TypeError — this method looks like a
        stale copy from a single-rate model and appears unused; confirm.
        """
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        y = self.dual_fc(y)
        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
        """Training forward pass; returns log-probabilities over output levels."""
        c = self.frame_rate_network(features, periods)
        y, _ = self.sample_rate_network(signals, c, gru_states)
        log_probs = torch.log_softmax(y, dim=-1)
        return log_probs

    def generate(self, features, periods, lpcs):
        """Autoregressively synthesize int16 PCM from features, periods and LPCs.

        Keeps three rolling signal buffers (prediction, last_error,
        last_signal), each offset by lpc_order samples, and applies a 0.85
        de-emphasis memory before clipping to int16.
        """
        with torch.no_grad():
            device = self.parameters().__next__().device
            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            num_input_signals = len(self.signals)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
            # signal buffers (lpc_order + 1 samples of leading history)
            last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0
            # state buffers; 128 is the u-law code for zero
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))
            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
            # conditioning signals for subconditioner a
            c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
            # conditioning signals for subconditioner b
            c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)
            # signal dict
            signal_dict = {
                'prediction' : prediction,
                'last_error' : last_error,
                'last_signal' : last_signal
            }
            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)
            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                # negated, time-reversed LPC coefficients for direct dot-product
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]
                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i
                    # calculate newest prediction
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
                    # prepare input: one slot per grouped timestep, newest last
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )
                    # run GRU A
                    embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals = torch.flatten(embed_signals, 2)
                    y = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)
                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input
                        # NOTE(review): the slot stride uses num_input_signals
                        # (= len(self.signals)), but the buffer is laid out
                        # with len(self.signals_idx_a) entries per slot — if
                        # those lengths differ this writes the wrong slots;
                        # confirm against the training-path layout.
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                c_signals_a[idx + slot * num_input_signals] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )
                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))
                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)
                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )
                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))
                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)
                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))
                            # update signals (buffers carry a one-sample delay)
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig
                            # de-emphasis filter before int16 output
                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))
                            # increase positions
                            pcm_position += 1
                            output_position += 1
                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
            return output

View File

@@ -0,0 +1,64 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import yaml
from models import model_dict
# debug mode bypasses the command line for quick interactive runs
debug = False
if debug:
    args = type('dummy', (object,),
        {
            'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml',
            'hierarchical_sampling' : False
        })()
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default=False)', default=False)
    args = parser.parse_args()

# load setup yaml describing the model
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# check model; fall back to the default lpcnet when no model entry is present
if 'model' not in setup['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# create model and report its complexity at 16 kHz
model = model_dict[model_name](setup['lpcnet']['config'])
gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling)

View File

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import os
from uuid import UUID
from collections import OrderedDict
import pickle
import torch
import numpy as np
import utils
# command line interface: point this script at the output folder of a
# multi-run experiment; aggregated results are written to <csv> and a
# companion pickle file
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input folder containing multi-run output")
parser.add_argument("tag", type=str, help="tag for multi-run experiment")
parser.add_argument("csv", type=str, help="name for output csv")
def is_uuid(val):
    """Return True if *val* parses as a UUID string, False otherwise."""
    try:
        UUID(val)
        return True
    except (ValueError, TypeError, AttributeError):
        # UUID() raises ValueError for malformed strings and
        # TypeError/AttributeError for non-string inputs; the original
        # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        return False
def collect_results(folder):
    """Gather training/testing metrics for one run folder into an OrderedDict.

    Expects the layout produced by loop_run.sh:
    <folder>/training/{checkpoints, out_finalize.txt} and
    <folder>/testing/final.
    """
    training_folder = os.path.join(folder, 'training')
    testing_folder = os.path.join(folder, 'testing')
    # validation loss from the first finalize-epoch checkpoint
    checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
    validation_loss = checkpoint['validation_loss']
    # eval_warpq: last WARP-Q score logged during finalization
    eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]
    # testing results: index 0 is presumably the mean of each statistic — confirm
    # against utils.data.collect_test_stats
    testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))
    results = OrderedDict()
    results['eval_loss'] = validation_loss
    results['eval_warpq'] = eval_warpq
    results['pesq_mean'] = testing_results['pesq'][0]
    results['warpq_mean'] = testing_results['warpq'][0]
    results['pitch_error_mean'] = testing_results['pitch_error'][0]
    results['voicing_error_mean'] = testing_results['voicing_error'][0]
    return results
def print_csv(path, results, tag, ranks=None, header=True):
    """Write per-run metric values (and optional ranks) to a CSV file.

    `results` maps run-id -> {metric: float}; `ranks`, when given, maps
    run-id -> {rank_name: int}. All rows share the metric/rank keys of the
    first entry.
    """
    metrics = next(iter(results.values())).keys()
    rank_keys = next(iter(ranks.values())).keys() if ranks is not None else []
    lines = []
    if header:
        cells = ["uuid, tag"]
        cells += [f", {metric}" for metric in metrics]
        cells += [f", {rank}" for rank in rank_keys]
        lines.append("".join(cells))
    for run_id, values in results.items():
        cells = [f"{run_id}, {tag}"]
        cells += [f", {val:10.8f}" for val in values.values()]
        cells += [f", {ranks[run_id][rank]:4d}" for rank in rank_keys]
        lines.append("".join(cells))
    with open(path, 'w') as f:
        f.write("\n".join(lines))
        if lines:
            f.write("\n")
def get_ranks(results):
    """Compute 1-based per-metric ranks for every run in `results`.

    Metrics in `positive` are ranked descending (larger is better);
    everything else ascending (smaller is better). Ties keep the
    insertion order of `results` (stable sort).
    """
    metrics = list(next(iter(results.values())).keys())
    # metrics where a larger value means a better run
    positive = {'pesq_mean', 'mix'}
    ranks = OrderedDict((run_id, OrderedDict()) for run_id in results)
    for metric in metrics:
        descending = metric in positive
        ordered = sorted(results, key=lambda run_id: results[run_id][metric], reverse=descending)
        for position, run_id in enumerate(ordered, start=1):
            ranks[run_id]['rank_' + metric] = position
    return ranks
def analyse_metrics(results):
    """Print the raw metric matrix (one row per metric, one column per run)."""
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
    rows = [[entry[metric] for entry in results.values()] for metric in metrics]
    x = np.array(rows)
    print(x)
def add_mix_metric(results):
    """Attach a combined 'mix' score to every entry of `results` (in place).

    Each metric column is sign-flipped so that larger is always better,
    z-normalized across runs, and the per-run mean of the normalized
    scores becomes the 'mix' value. Also prints the covariance matrix of
    the normalized scores for inspection.
    """
    metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
    # rows of the (runs x metrics) matrix; signs flip "lower is better" metrics
    columns = [[entry[metric] for entry in results.values()] for metric in metrics]
    x = np.array(columns).transpose() * np.array([-1, 1, -1, -1, -1])
    z = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
    print(f"covariance matrix for normalized scores of {metrics}:")
    print(np.cov(z.transpose()))
    score = np.mean(z, axis=1)
    for i, key in enumerate(results):
        results[key]['mix'] = score[i].item()
if __name__ == "__main__":
    args = parser.parse_args()
    # collect per-run results from every uuid-named subfolder
    uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])
    results = OrderedDict()
    for uuid in uuids:
        results[uuid] = collect_results(os.path.join(args.input, uuid))
    add_mix_metric(results)
    ranks = get_ranks(results)
    # normalize the output name so csv and pickle share the same basename
    csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'
    # bug fix: write the csv to the normalized path (was args.csv), so the
    # csv and its companion pickle always end up side by side
    print_csv(csv, results, args.tag, ranks=ranks)
    with open(csv[:-4] + '.pickle', 'wb') as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)

View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Repeatedly trains, finalizes and tests an LPCNet setup; each round gets a
# fresh uuid-named output folder under OUTDIR (consumed by the results
# aggregation script).
case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; DEVICE=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    *) echo "loop_run.sh setup outdir name device rounds lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac

# hard-coded interpreter and test fixtures
PYTHON="/home/ubuntu/opt/miniconda3/envs/torch/bin/python"
TESTFEATURES=${LPCNEXT}/testitems/features/all_0_orig_features.f32
WARPQREFERENCE=${LPCNEXT}/testitems/wav/all_0_orig.wav
METRICS="warpq,pesq,pitch_error,voicing_error"
LPCNETDEMO=${LPCNET}/lpcnet_demo

for ((round = 1; round <= $ROUNDS; round++))
do
    echo
    echo round $round

    # per-round output layout: <OUTDIR>/<uuid>/{training,testing}
    UUID=$(uuidgen)
    TRAINOUT=${OUTDIR}/${UUID}/training
    TESTOUT=${OUTDIR}/${UUID}/testing
    CHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_last.pth
    FINALCHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_finalize_last.pth

    # run training
    echo "starting training..."
    $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT --device $DEVICE --test-features $TESTFEATURES --warpq-reference $WARPQREFERENCE

    # run finalization (continues from the last training checkpoint)
    echo "starting finalization..."
    $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT \
        --device $DEVICE --test-features $TESTFEATURES \
        --warpq-reference $WARPQREFERENCE \
        --finalize --initial-checkpoint $CHECKPOINT

    # create test configs
    $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig.yml "$NAME $UUID" $CHECKPOINT --lpcnet-demo $LPCNETDEMO
    $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig_finalize.yml "$NAME $UUID finalized" $FINALCHECKPOINT --lpcnet-demo $LPCNETDEMO

    # run tests
    echo "starting test 1 (no finalization)..."
    $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig.yml \
        $TESTITEMS ${TESTOUT}/prefinal --num-workers 8 \
        --num-testitems 400 --metrics $METRICS

    echo "starting test 2 (after finalization)..."
    $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig_finalize.yml \
        $TESTITEMS ${TESTOUT}/final --num-workers 8 \
        --num-testitems 400 --metrics $METRICS
done

View File

@@ -0,0 +1,67 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" script for creating animations from debug data
"""
import argparse
import sys
sys.path.append('./')
from utils.endoscopy import make_animation, read_data
# command line interface for turning endoscopy debug dumps into an mp4
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='endoscopy folder with debug output')
parser.add_argument('output', type=str, help='output file (will be auto-extended with .mp4)')
parser.add_argument('--start-index', type=int, help='index of first sample to be considered', default=0)
parser.add_argument('--stop-index', type=int, help='index of last sample to be considered', default=-1)
# NOTE(review): --interval is parsed but never passed to make_animation below —
# confirm whether it should be forwarded or removed
parser.add_argument('--interval', type=int, help='interval between frames in ms', default=20)
parser.add_argument('--half-window-length', type=int, help='half size of window for displaying signals', default=80)

if __name__ == "__main__":
    args = parser.parse_args()

    # ensure the output path carries an .mp4 extension
    filename = args.output if args.output.endswith('.mp4') else args.output + '.mp4'

    data = read_data(args.folder)
    make_animation(
        data,
        filename,
        start_index=args.start_index,
        stop_index = args.stop_index,
        half_signal_window_length=args.half_window_length
    )

View File

@@ -0,0 +1,17 @@
import argparse
import numpy as np
parser = argparse.ArgumentParser(description="sets s_t to augmented_s_t")
parser.add_argument('datafile', type=str, help='data.s16 file path')
args = parser.parse_args()

# open the interleaved int16 file for in-place modification; the memmap is
# flushed back to disk when the array is deallocated at interpreter exit
data = np.memmap(args.datafile, dtype='int16', mode='readwrite')

# signal is in data[1::2]
# last augmented signal is in data[0::2]
# i.e. the augmented sample stored at even slot 2(t+1) belongs to time t, so
# copying data[2::2] onto data[1:-1:2] aligns augmented_s_t with s_t
# (presumed from the layout comment above — confirm against the data writer)
data[1 : - 1 : 2] = data[2 : : 2]

View File

@@ -0,0 +1,17 @@
#!/bin/bash
# Launches one background loop_run.sh job queue per CUDA device; each queue
# executes ROUNDS train/test rounds and logs to <OUTDIR>/job_<i>_out.txt.
case $# in
    9) SETUP=$1; OUTDIR=$2; NAME=$3; NUMDEVICES=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
    *) echo "multi_run.sh setup outdir name num_devices rounds_per_device lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
esac

LOOPRUN=${LPCNEXT}/loop_run.sh

mkdir -p $OUTDIR

for ((i = 0; i < $NUMDEVICES; i++))
do
    echo "launching job queue for device $i"
    # nohup keeps the queue alive after this shell exits
    nohup bash $LOOPRUN $SETUP $OUTDIR "$NAME" "cuda:$i" $ROUNDS $LPCNEXT $LPCNET $TESTSUITE $TESTITEMS > $OUTDIR/job_${i}_out.txt &
done

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Runs test_lpcnet.py on every checkpoint found under FOLDER and writes one
# wav per checkpoint epoch into FOLDER/inference_test.
case $# in
    3) FEATURES=$1; FOLDER=$2; PYTHON=$3;;
    *) echo "run_inference_test.sh <features file> <output folder> <python path>"; exit;;
esac

SCRIPTFOLDER=$(dirname "$0")

mkdir -p $FOLDER/inference_test

# run inference for every checkpoint in the folder
# (epoch number is parsed from the checkpoint filename)
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
    tmp=$(basename $fn)
    tmp=${tmp%.pth}
    epoch=${tmp#checkpoint_epoch_}
    echo "running inference with checkpoint $fn..."
    $PYTHON $SCRIPTFOLDER/../test_lpcnet.py $FEATURES $fn $FOLDER/inference_test/output_epoch_${epoch}.wav
done

View File

@@ -0,0 +1,54 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" script for updating checkpoints with new setup entries
Use this script to update older outputs with newly introduced
parameters. (Saves us the trouble of backward compatibility)
"""
import argparse
import torch
# command line interface: rewrite a checkpoint in place with updated setup entries
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint_file', type=str, help='checkpoint to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)

args = parser.parse_args()

checkpoint = torch.load(args.checkpoint_file, map_location='cpu')

# update model entry (idiomatic None check instead of type comparison)
if args.model is not None:
    checkpoint['setup']['lpcnet']['model'] = args.model

torch.save(checkpoint, args.checkpoint_file)

View File

@@ -0,0 +1,22 @@
#!/bin/bash
# Updates an older training output folder in place: patches the model entry
# into setup.yml and into every checkpoint found below FOLDER.
case $# in
    3) FOLDER=$1; MODEL=$2; PYTHON=$3;;
    *) echo "update_output_folder.sh folder model python"; exit;;
esac

SCRIPTFOLDER=$(dirname "$0")

# update setup
echo "updating $FOLDER/setup.py..."
$PYTHON $SCRIPTFOLDER/update_setups.py $FOLDER/setup.yml --model $MODEL

# update checkpoints
for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
do
    echo "updating $fn..."
    $PYTHON $SCRIPTFOLDER/update_checkpoints.py $fn --model $MODEL
done

View File

@@ -0,0 +1,57 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" script for updating setup files with new setup entries
Use this script to update older outputs with newly introduced
parameters. (Saves us the trouble of backward compatibility)
"""
import argparse
import yaml
# command line interface: rewrite a setup yaml in place with updated entries
parser = argparse.ArgumentParser()
parser.add_argument('setup_file', type=str, help='setup to be updated')
parser.add_argument('--model', type=str, help='model update', default=None)

args = parser.parse_args()

# load setup
with open(args.setup_file, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# update model entry (idiomatic None check instead of type comparison)
if args.model is not None:
    setup['lpcnet']['model'] = args.model

# dump result
with open(args.setup_file, 'w') as f:
    yaml.dump(setup, f)

View File

@@ -0,0 +1,89 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse

import torch
import numpy as np

from models import model_dict
from utils.data import load_features
from utils.wav import wavwrite16

debug = False
if debug:
    # hard-coded arguments for interactive debugging sessions
    args = type('dummy', (object,),
    {
        'features' : 'features.f32',
        'checkpoint' : 'checkpoint.pth',
        'output' : 'out.wav',
        'version' : 2
    })()
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('features', type=str, help='feature file')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--version', type=int, help='feature version', default=2)
    args = parser.parse_args()

torch.set_num_threads(2)

version = args.version
feature_file = args.features
checkpoint_file = args.checkpoint
output_file = args.output
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model
if 'model' not in checkpoint['setup']['lpcnet']:
    print(f'warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = checkpoint['setup']['lpcnet']['model']

model = model_dict[model_name](checkpoint['setup']['lpcnet']['config'])
model.load_state_dict(checkpoint['state_dict'])

# fix: the --version argument was parsed but never forwarded, so v1 feature
# files were always decoded with the v2 layout
data = load_features(feature_file, version=version)
output = model.generate(data['features'], data['periods'], data['lpcs'])

wavwrite16(output_file, output.numpy(), 16000)

View File

@@ -0,0 +1,272 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
try:
    import git
    has_git = True
except ImportError:
    # gitpython is optional; repo metadata is simply skipped when it is missing.
    # (bare except would also swallow KeyboardInterrupt/SystemExit)
    has_git = False
import yaml
import torch
from torch.optim.lr_scheduler import LambdaLR
from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16
debug = False
if debug:
    # hard-coded arguments for interactive debugging sessions
    args = type('dummy', (object,),
    {
        'setup' : 'setup.yml',
        'output' : 'testout',
        'device' : None,
        'test_features' : None,
        'finalize': False,
        'initial_checkpoint': None,
        # fix: argparse exposes --no-redirect as args.no_redirect, so the
        # debug stand-in must use the underscore name; the hyphenated key
        # was unreachable as an attribute and raised AttributeError later
        'no_redirect': False
    })()
else:
    parser = argparse.ArgumentParser("train_lpcnet.py")
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('output', type=str, help='output path')
    parser.add_argument('--device', type=str, help='compute device', default=None)
    parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
    parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
    parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
    parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')
    args = parser.parse_args()
torch.set_num_threads(4)

# load setup
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# finalization: one low-lr epoch on top of an existing checkpoint,
# with sparsification schedules disabled
if args.finalize:
    if args.initial_checkpoint is None:
        raise ValueError('finalization requires initial checkpoint')
    if 'sparsification' in setup['lpcnet']['config']:
        for sp_job in setup['lpcnet']['config']['sparsification'].values():
            sp_job['start'], sp_job['stop'] = 0, 0
    setup['training']['lr'] = 1.0e-5
    setup['training']['lr_decay_factor'] = 0.0
    setup['training']['epochs'] = 1
    checkpoint_prefix = 'checkpoint_finalize'
    output_prefix = 'output_finalize'
    setup_name = 'setup_finalize.yml'
    output_file = 'out_finalize.txt'
else:
    checkpoint_prefix = 'checkpoint'
    output_prefix = 'output'
    setup_name = 'setup.yml'
    output_file = 'out.txt'

# check model
if 'model' not in setup['lpcnet']:
    print(f'warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")
    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')
    if reply == 'n':
        # fix: os._exit() requires a status argument and raised TypeError
        # instead of terminating; sys.exit also lets normal interpreter
        # shutdown (flushes, atexit) run
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()
        if is_dirty:
            print("warning: repo is dirty")
        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        # best effort: the working tree may not be a git checkout or may
        # lack a remote; don't let that abort training (was a bare except)
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# prepare inference test if wanted
run_inference_test = False
if args.test_features is not None:
    test_features = load_features(args.test_features)
    inference_test_dir = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_test_dir, exist_ok=True)
    run_inference_test = True
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
lpcnet_config = setup['lpcnet']['config']
data = LPCNetDataset(setup['dataset'],
                     features=lpcnet_config['features'],
                     input_signals=lpcnet_config['signals'],
                     target=lpcnet_config['target'],
                     frames_per_sample=setup['training']['frames_per_sample'],
                     feature_history=lpcnet_config['feature_history'],
                     feature_lookahead=lpcnet_config['feature_lookahead'],
                     lpc_gamma=lpcnet_config.get('lpc_gamma', 1))

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetDataset(setup['validation_dataset'],
                                    features=lpcnet_config['features'],
                                    input_signals=lpcnet_config['signals'],
                                    target=lpcnet_config['target'],
                                    frames_per_sample=setup['training']['frames_per_sample'],
                                    feature_history=lpcnet_config['feature_history'],
                                    feature_lookahead=lpcnet_config['feature_lookahead'],
                                    lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](setup['lpcnet']['config'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device (idiom fix: `is None`, not type(...) == type(None))
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer only sees trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler: lr decays as 1 / (1 + decay * step)
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

# loss
criterion = torch.nn.NLLLoss()

# model checkpoint template; state_dict/loss are refreshed every epoch
checkpoint = {
    'setup'      : setup,
    'state_dict' : model.state_dict(),
    'loss'       : -1
}

if not args.no_redirect:
    # NOTE: stdout stays redirected (and the file open) for the rest of the
    # process lifetime — deliberate for long training runs
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")
# main training loop: train, checkpoint, optionally validate and
# synthesize a test file every epoch
best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint (checkpoint dict is reused; only these fields change)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss
        # track best model by validation loss
        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
            best_loss = validation_loss

    # per-epoch snapshot plus a rolling "last" copy
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

    # run inference test (generation runs on CPU, then the model is moved back)
    if run_inference_test:
        model.to("cpu")
        print("running inference test...")
        output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])
        testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')
        wavwrite16(testfilename, output.numpy(), 16000)
        model.to(device)

    print()

View File

@@ -0,0 +1,4 @@
from . import sparsification
from . import data
from . import pcm
from . import sample

View File

@@ -0,0 +1,141 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import torch
import numpy as np
def load_features(feature_file, version=2):
    """Read a raw float32 LPCNet feature file and split it into model inputs.

    Returns a dict with 'features' (cepstrum + pitch correlation),
    'periods' (integer pitch periods) and 'lpcs' (LP coefficients).
    """
    # per-version frame layout: (start, stop) column ranges within one frame
    if version == 2:
        frame_length = 36
        cepstrum_cols, period_cols, corr_cols, lpc_cols = (0, 18), (18, 19), (19, 20), (20, 36)
    elif version == 1:
        frame_length = 55
        cepstrum_cols, period_cols, corr_cols, lpc_cols = (0, 18), (36, 37), (37, 38), (39, 55)
    else:
        raise ValueError(f'unknown feature version: {version}')

    raw = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw = raw.reshape((-1, frame_length))

    features = torch.cat(
        (
            raw[:, cepstrum_cols[0] : cepstrum_cols[1]],
            raw[:, corr_cols[0] : corr_cols[1]]
        ),
        dim=1
    )

    lpcs = raw[:, lpc_cols[0] : lpc_cols[1]]
    # map the float period feature back to an integer period in samples
    periods = (0.1 + 50 * raw[:, period_cols[0] : period_cols[1]] + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Re-create an int16 training data file from a synthesized signal.

    Applies frame-wise pre-emphasis to `signal_path`, writes the result to a
    sibling `*_preemph.raw` file, then writes it (shifted by `offset` samples)
    into a zeroed copy of `reference_data_path`'s layout at `new_data_path`.
    NOTE(review): the interleaved layout below assumes the reference file
    stores two int16 streams per sample — confirm against the data writer.
    """
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)
    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
    # processing is done in 160-sample (10 ms @ 16 kHz) frames
    assert len(signal) % 160 == 0
    num_frames = len(signal) // 160
    # one-sample filter memory carried across frame boundaries
    mem = np.zeros(1)
    for fr in range(len(signal)//160):
        # first-order pre-emphasis y[n] = x[n] - preemph_factor * x[n-1]
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
    new_data[:] = 0
    N = len(signal) - offset
    # write the pre-emphasized signal into both interleaved positions
    new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
def parse_warpq_scores(output_file):
    """ extracts warpq scores from output file """
    marker = "WARP-Q score:"
    with open(output_file, "r") as f:
        return [float(line.split(marker)[-1]) for line in f if line.startswith(marker)]
def parse_stats_file(file):
    """Parse the first three 'label: value' lines of a stats file.

    Returns (mean, bt_mean, top_mean) as floats.
    """
    with open(file, "r") as f:
        lines = f.readlines()
    values = [float(line.split(":")[-1]) for line in lines[:3]]
    return values[0], values[1], values[2]
def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
    results = dict()
    for filename in os.listdir(test_folder):
        if not filename.startswith('stats_'):
            continue
        metric = filename[len("stats_") : -len(".txt")]
        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")
        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, filename))
        results[metric] = [mean, bt_mean, top_mean]
    return results

View File

@@ -0,0 +1,234 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
    """Compute the reset/update/new gate activations of one GRU step.

    Uses the layer-0 weights of a torch.nn.GRU; `input` and `state` are
    squeezed to vectors before the matmuls.
    """
    h = gru.hidden_size
    x_proj = torch.matmul(gru.weight_ih_l0, input.squeeze())
    h_proj = torch.matmul(gru.weight_hh_l0, state.squeeze())
    b_ih, b_hh = gru.bias_ih_l0, gru.bias_hh_l0

    # torch packs the three gates as consecutive [r, z, n] slices
    r_sl = slice(0 * h, 1 * h)
    z_sl = slice(1 * h, 2 * h)
    n_sl = slice(2 * h, 3 * h)

    reset_gate = torch.sigmoid(x_proj[r_sl] + b_ih[r_sl] + h_proj[r_sl] + b_hh[r_sl])
    update_gate = torch.sigmoid(x_proj[z_sl] + b_ih[z_sl] + h_proj[z_sl] + b_hh[z_sl])
    # reset gate modulates only the recurrent contribution of the new gate
    new_gate = torch.tanh(x_proj[n_sl] + b_ih[n_sl] + reset_gate * (h_proj[n_sl] + b_hh[n_sl]))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """
    global _folder
    _folder = folder
    # keep an existing folder but warn: stale files may mix with new ones
    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key """
    global _state

    # torch tensors are serialized via their numpy representation
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    if key not in _state:
        # first write: open the binary sink and record the data signature
        _state[key] = {
            'fid'   : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs'    : fs,
            'dim'   : tuple(data.shape),
            'dtype' : str(data.dtype)
        }
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # subsequent writes must match the recorded signature exactly
        entry = _state[key]
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {entry['fs']} vs. {fs}")
        if entry['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {entry['dtype']} vs. {str(data.dtype)}")
        if entry['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {entry['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up """
    # NOTE(review): the folder argument is accepted but unused — confirm intent
    for entry in _state.values():
        entry['fid'].close()
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """
    return_dict = dict()
    for name in os.listdir(folder):
        if not name.endswith('.yml'):
            continue
        key = name[:-4]
        # the .yml sidecar holds fs/dim/dtype metadata for the raw .bin blob
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)
        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=value['dtype'])
        value['data'] = raw.reshape((-1,) + value['dim'])
        return_dict[key] = value
    return return_dict
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)"""
    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim
    else:
        pixel_count = shape[0]

    # a single pixel stays one-dimensional
    if pixel_count == 1:
        return (1,)

    # largest divisor of pixel_count not exceeding sqrt(count / ratio)
    cols = int((pixel_count / target_ratio) ** .5)
    while pixel_count % cols:
        cols -= 1

    return (pixel_count // cols, cols)
def get_type_and_shape(shape):
    """Classify a stored dimension as a scalar 'plot' or a 2d 'image' display.

    Returns (kind, display_shape).
    """
    # zero-dimensional data is treated as a single scalar
    if len(shape) == 0:
        shape = (1,)

    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return 'plot', (1, )

    # an explicit 2d layout is kept as-is (unless degenerate)
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render all recorded endoscopy keys as a synchronized mp4 animation.

    `data` is the dict returned by read_data(); scalar keys are drawn as
    sliding-window line plots, multi-dimensional keys as images.
    """
    # determine plot setup: near-square grid of subplots (rows/cols ~ 3/4)
    num_keys = len(data.keys())
    num_rows = int((num_keys * 3/4) ** .5)
    num_cols = (num_keys + num_rows - 1) // num_rows
    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)
    display = dict()
    # keys may be sampled at different rates; everything is indexed relative
    # to the fastest rate
    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])
    keys = sorted(data.keys())
    # inspect data: decide per key how it is displayed
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)
        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max
    # clamp the animated range so the sliding window never leaves the data
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples
    stop_index = min(stop_index, num_samples - half_signal_window_length)
    # actual plotting: one list of artists per animation frame
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but unused below —
            # presumably intended for indexing slower-rate keys; confirm
            feature_index = int(round(index * display[key]['down_factor']))
            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
        frames.append(ims)
    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if not filename.endswith('.mp4'):
        filename += '.mp4'
    ani.save(filename)

View File

@@ -0,0 +1,3 @@
from .dual_fc import DualFC
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding

View File

@@ -0,0 +1,44 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
class DualFC(nn.Module):
    """Dual fully-connected output layer.

    Computes alpha * tanh(dense1(x)) + beta * tanh(dense2(x)) with learnable
    scalar mixing weights alpha and beta.
    """
    def __init__(self, input_dim, output_dim):
        super(DualFC, self).__init__()
        self.dense1 = nn.Linear(input_dim, output_dim)
        self.dense2 = nn.Linear(input_dim, output_dim)
        # learnable mixing weights, both initialized to 0.5
        self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

    def forward(self, x):
        branch_a = torch.tanh(self.dense1(x))
        branch_b = torch.tanh(self.dense2(x))
        return self.alpha * branch_a + self.beta * branch_b

View File

@@ -0,0 +1,71 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module implementing PCM embeddings for LPCNet """
import math as m
import torch
from torch import nn
class PCMEmbedding(nn.Module):
    """Embedding table for discrete PCM levels.

    Weights are initialized as uniform noise plus a level-proportional offset,
    so embeddings grow with the PCM level on average.
    """
    def __init__(self, embed_dim=128, num_levels=256):
        super(PCMEmbedding, self).__init__()

        self.embed_dim = embed_dim
        self.num_levels = num_levels

        # fix: original referenced non-existent attribute self.num_dim
        self.embedding = nn.Embedding(self.num_levels, self.embed_dim)

        # initialize: sqrt(12)-scaled uniform noise plus centered level offset,
        # scaled down by 0.1
        with torch.no_grad():
            num_rows, num_cols = self.num_levels, self.embed_dim
            a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
            for i in range(num_rows):
                a[i, :] += m.sqrt(12) * (i - num_rows / 2)
            self.embedding.weight[:, :] = 0.1 * a

    def forward(self, x):
        # fix: original had typo self.embeddint, raising AttributeError
        return self.embedding(x)
class DifferentiablePCMEmbedding(PCMEmbedding):
    """PCM embedding with linear interpolation between adjacent levels,
    keeping the lookup differentiable w.r.t. a real-valued input."""
    def __init__(self, embed_dim, num_levels=256):
        super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)

    def forward(self, x):
        # fix: original computed x_int = (x - floor(x)).long(), i.e. the
        # fractional part, which is always 0; the index must be floor(x)
        x_int = torch.floor(x).detach().long()
        # fractional part drives the interpolation; unsqueeze so it
        # broadcasts against the trailing embedding dimension
        x_frac = (x - x_int).unsqueeze(-1)
        # clamp indices into the valid embedding range [0, num_levels - 1]
        # (fix: torch.minimum requires a tensor and num_levels itself is out
        # of range for nn.Embedding)
        x_int = torch.clamp(x_int, 0, self.num_levels - 1)
        x_next = torch.clamp(x_int + 1, max=self.num_levels - 1)

        embed_0 = self.embedding(x_int)
        embed_1 = self.embedding(x_next)

        return (1 - x_frac) * embed_0 + x_frac * embed_1

View File

@@ -0,0 +1,497 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from re import sub
import torch
from torch import nn
def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):
    """Factory returning a subconditioner instance for the given method name.

    Raises KeyError for unknown method names.
    """
    constructors = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    cls = constructors[method]
    return cls(number_of_subsamples, pcm_embedding_size, state_size,
               pcm_levels, number_of_signals, **kwargs)
class Subconditioner(nn.Module):
    """ upsampling by subconditioning

    Upsamples a sequence of states conditioning on pcm signals and
    optionally a feature vector. Abstract base class: subclasses must
    implement forward, single_step and get_output_dim.
    """
    def __init__(self):
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        # abstract (fix: raise NotImplementedError instead of bare Exception)
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        # abstract
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        # abstract
        raise NotImplementedError("Base class should not be called")
class AdditiveSubconditioner(Subconditioner):
    # Upsamples states by adding per-subsample signal embeddings onto them;
    # requires pcm_embedding_size == state_size.
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition """
        super(AdditiveSubconditioner, self).__init__()
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        # addition only works when embedding and state share the same dimension
        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')
        # index 0 has no embedding: the first subsample uses the state as-is
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            # signals[:, i::s] selects the i-th subsample of every frame
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)
            # embeddings accumulate: each subsample adds onto the previous state
            new_states = new_states + embed
            c_states.append(new_states)
        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            # first subsample passes the state through unchanged
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c
        return c_state

    def get_output_dim(self, index):
        # additive conditioning never changes the state dimension
        return self.state_size

    def get_average_flops_per_step(self):
        # embeddings are added on s - 1 out of s subsamples
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
class ConcatenativeSubconditioner(Subconditioner):
def __init__(self,
number_of_subsamples,
pcm_embedding_size,
state_size,
pcm_levels,
number_of_signals,
recurrent=True,
**kwargs):
""" subconditioning by concatenation """
super(ConcatenativeSubconditioner, self).__init__()
self.number_of_subsamples = number_of_subsamples
self.pcm_embedding_size = pcm_embedding_size
self.state_size = state_size
self.pcm_levels = pcm_levels
self.number_of_signals = number_of_signals
self.recurrent = recurrent
self.embeddings = []
start_index = 0
if self.recurrent:
start_index = 1
self.embeddings.append(None)
for i in range(start_index, self.number_of_subsamples):
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
self.add_module('pcm_embedding_' + str(i), embedding)
self.embeddings.append(embedding)
def forward(self, states, signals):
""" creates list of subconditioned states
Parameters:
-----------
states : torch.tensor
states of shape (batch, seq_length // s, state_size)
signals : torch.tensor
signals of shape (batch, seq_length, number_of_signals)
Returns:
--------
c_states : list of torch.tensor
list of s subconditioned states
"""
s = self.number_of_subsamples
if self.recurrent:
c_states = [states]
start = 1
else:
c_states = []
start = 0
new_states = states
for i in range(start, self.number_of_subsamples):
embed = self.embeddings[i](signals[:, i::s])
# reduce signal dimension
embed = torch.flatten(embed, -2)
if self.recurrent:
new_states = torch.cat((new_states, embed), dim=-1)
else:
new_states = torch.cat((states, embed), dim=-1)
c_states.append(new_states)
return c_states
def single_step(self, index, state, signals):
""" carry out single step for inference
Parameters:
-----------
index : int
position in subconditioning batch
state : torch.tensor
state to sub-condition
signals : torch.tensor
signals for subconditioning, all but the last dimensions
must match those of state
Returns:
c_state : torch.tensor
subconditioned state
"""
if index == 0 and self.recurrent:
c_state = state
else:
embed_signals = self.embeddings[index](signals)
c = torch.flatten(embed_signals, -2)
if not self.recurrent and index > 0:
# overwrite previous conditioning vector
c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
else:
c_state = torch.cat((state, c), dim=-1)
return c_state
return c_state
def get_average_flops_per_step(self):
    """Return the average FLOP cost of one subconditioning step.

    This variant only performs embedding lookups and concatenations,
    which are not counted as FLOPs here, hence 0.
    """
    return 0
def get_output_dim(self, index):
    """Return the feature dimension of the subconditioned state at `index`."""
    appended = self.pcm_embedding_size * self.number_of_signals
    if not self.recurrent:
        # non-recurrent: exactly one conditioning vector is appended
        return self.state_size + appended
    # recurrent: one conditioning vector is appended per sub-step so far
    return self.state_size + index * appended
class ModulativeSubconditioner(Subconditioner):
    """Subconditioner that modulates the running state with a learned
    scale (alpha) and shift (beta) derived from embedded signal values:
    new_state = tanh((1 + alpha) * state + beta).
    """

    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """
        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        # width of the conditioning vector fed to the alpha/beta layers:
        # one embedding per signal, optionally plus a compressed state copy
        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            # compresses the state down to embedding size before re-injection
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 is unused: sub-step 0 passes the state through unchanged
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            # registered via add_module so plain-list storage still tracks params
            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                # feed a compressed copy of the current state back in
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            # sub-step 0: state passes through unchanged
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)
            return c_state

        return c_state

    def get_output_dim(self, index):
        # modulation preserves the state size for every sub-step
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples
        # estimate activation by 10 flops
        # c_state = torch.tanh((1 + alpha) * state + beta)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B)
        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps (sub-step 0 is free)
        flops *= (s - 1) / s

        return flops
class ComparitiveSubconditioner(Subconditioner):
    """Subconditioner that modulates a transformed (compressed) state with
    alpha/beta layers shared across sub-steps."""

    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison """
        super(ComparitiveSubconditioner, self).__init__()

        # Bug fix: the original referenced self.pcm_embedding_size,
        # self.state_size, self.number_of_subsamples, etc. without ever
        # assigning them (AttributeError on construction), and misspelled
        # number_of_signals as "number_of_signales".
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        # alpha/beta layers shared across sub-steps (used by forward)
        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms; index 0 unused (pass-through step)
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        """Return the list of s subconditioned states (see sibling classes)."""
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)
            comp_states = self.state_transforms[i](new_states)
            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))
            # new state obtained by modulating the compressed previous state.
            # NOTE(review): alpha/beta have state_size features while
            # comp_states has comparison_size — these only broadcast when
            # pcm_embedding_size == state_size; confirm intended sizes.
            new_states = torch.tanh((1 + alpha) * comp_states + beta)
            c_states.append(new_states)

        return c_states

View File

@@ -0,0 +1,65 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def find(a, v):
    """Return the index of the first occurrence of v in a, or -1 if absent.

    Parameters:
    -----------
    a : sequence
        sequence supporting .index() (e.g. list, str)
    v : object
        value to look for

    Returns:
    --------
    int
        index of first match, or -1 when v is not found
    """
    try:
        return a.index(v)
    except ValueError:
        # .index raises ValueError when v is not present; the original
        # bare `except:` would also have hidden unrelated errors
        return -1
def interleave_tensors(tensors, dim=-2):
    """ interleave list of tensors along sequence dimension """
    # stack the tensors on a fresh axis at `dim`, then collapse that axis
    # into the sequence axis so elements alternate between inputs
    stacked = torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)
    return torch.flatten(stacked, dim - 1, dim)
def _interleave(x, pcm_levels=256):
repeats = pcm_levels // (2*x.size(-1))
x = x.unsqueeze(-1)
p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
return p
def get_pdf_from_tree(x):
    """Convert hierarchical binary-tree split probabilities into a flat
    pdf over the pcm levels by multiplying expanded levels of the tree."""
    pcm_levels = x.size(-1)
    pdf = _interleave(x[..., 1:2])
    width = 4
    while width <= pcm_levels:
        # refine the pdf with the next tree level's split probabilities
        pdf = pdf * _interleave(x[..., width // 2 : width])
        width *= 2
    return pdf

View File

@@ -0,0 +1,35 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
def clip_to_int16(x):
    """Clamp x to the signed 16-bit integer range [-32768, 32767]."""
    lo = -2 ** 15
    hi = 2 ** 15 - 1
    if x < lo:
        return lo
    if x > hi:
        return hi
    return x

View File

@@ -0,0 +1,44 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def sample_excitation(probs, pitch_corr):
    """Sample one excitation index from a sharpened, tail-truncated pdf.

    High pitch correlation lowers the sampling temperature; small tail
    probabilities are cut off before sampling.
    """
    renorm = lambda p: p / (p.sum() + 1e-18)
    # lowering the temperature for strongly periodic frames
    sharpened = renorm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
    # cut-off tails
    truncated = renorm(torch.maximum(sharpened - 0.002, torch.FloatTensor([0])))
    # sample
    return torch.multinomial(truncated.squeeze(), 1)

View File

@@ -0,0 +1,2 @@
from .gru_sparsifier import GRUSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step

View File

@@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    Keeps the highest-energy sub-blocks so that approximately `density`
    of all blocks survive; the remaining blocks are zeroed out.

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : float
        target density in [0, 1]
    block_size : [int, int]
        block size dimensions
    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
    return_mask : bool
        If true, the expanded binary mask is returned alongside the matrix

    Returns:
    --------
    masked matrix, or (masked matrix, mask) when return_mask is True
    """
    m, n   = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix   = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # create mask
    if number_of_survivors == 0:
        # bug fix: the original used threshold = 0 here, which kept EVERY
        # block (energies are never < 0); zero survivors must zero them all
        mask = torch.zeros_like(block_energies)
    else:
        # masking threshold: energy of the weakest surviving block
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
        mask = torch.ones_like(block_energies)
        mask[block_energies < threshold] = 0

    # expand block mask to matrix resolution
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
    """Estimate FLOPs for one GRU time step, honouring block-sparsity densities.

    Densities are read from sparsification_dict entries of the form
    key -> (density, ...); missing keys count as fully dense.
    """
    input_size = gru.input_size
    hidden_size = gru.hidden_size

    # mean density over the three input (resp. recurrent) weight matrices
    density_of = lambda key: sparsification_dict.get(key, [1])[0]
    input_density = (density_of('W_ir') + density_of('W_in') + density_of('W_iz')) / 3
    recurrent_density = (density_of('W_hr') + density_of('W_hn') + density_of('W_hz')) / 3

    flops = 0
    if not drop_input:
        # input matrix vector multiplications (2 flops per MAC, 3 gates)
        flops += 2 * 3 * input_size * hidden_size * input_density
    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
    # biases
    flops += 6 * hidden_size
    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops

View File

@@ -0,0 +1,187 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .common import sparsify_matrix
class GRUSparsifier:
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
            should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...         'W_ir' : (0.5, [2, 2], False),
        ...         'W_iz' : (0.6, [2, 2], False),
        ...         'W_in' : (0.7, [2, 2], False),
        ...         'W_hr' : (0.1, [4, 4], True),
        ...         'W_hz' : (0.2, [4, 4], True),
        ...         'W_hn' : (0.3, [4, 4], True),
        ... }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...         sparsifier.step()
        """
        # just copying parameters...
        self.start      = start
        self.stop       = stop
        self.interval   = interval
        self.exponent   = exponent
        self.task_list  = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        # most recent mask per weight key, used to detect mask churn after stop
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def step(self, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None
        """
        # compute current interpolation factor
        self.step_counter += 1

        if self.step_counter < self.start:
            return
        elif self.step_counter < self.stop:
            # update only every self.interval-th interval
            if self.step_counter % self.interval:
                return

            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            # past stop: enforce the final target density on every call
            alpha = 0

        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: rows [i*h, (i+1)*h) of weight_ih_l0 hold gate i
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        # sparsify the gate's row slice in place
                        gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        # warn when the mask still changes after the schedule finished
                        if type(self.last_masks[key]) != type(None):
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask

                # recurrent weights: same layout within weight_hh_l0
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if type(self.last_masks[key]) != type(None):
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask
# manual smoke test: sparsify a small GRU over 100 steps and print densities
if __name__ == "__main__":
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)

View File

@@ -0,0 +1,157 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from models import multi_rate_lpcnet
import copy
setup_dict = dict()

# dataset layout templates: describe how features/signals are packed in the
# raw training files (version 2 and legacy version 1 formats)
dataset_template_v2 = {
    'version' : 2,
    'feature_file' : 'features.f32',
    'signal_file' : 'data.s16',
    'frame_length' : 160,
    'feature_frame_length' : 36,
    'signal_frame_length' : 2,
    'feature_dtype' : 'float32',
    'signal_dtype' : 'int16',
    'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
    'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
}

dataset_template_v1 = {
    'version' : 1,
    'feature_file' : 'features.f32',
    'signal_file' : 'data.u8',
    'frame_length' : 160,
    'feature_frame_length' : 55,
    'signal_frame_length' : 4,
    'feature_dtype' : 'float32',
    'signal_dtype' : 'uint8',
    'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
    'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
}

# lpcnet
# baseline LPCNet model hyper-parameters (architecture, I/O layout, sparsification schedule)
lpcnet_config = {
    'frame_size' : 160,
    'gru_a_units' : 384,
    'gru_b_units' : 64,
    'feature_conditioning_dim' : 128,
    'feature_conv_kernel_size' : 3,
    'period_levels' : 257,
    'period_embedding_dim' : 64,
    'signal_embedding_dim' : 128,
    'signal_levels' : 256,
    'feature_dimension' : 19,
    'output_levels' : 256,
    'lpc_gamma' : 0.9,
    'features' : ['cepstrum', 'periods', 'pitch_corr'],
    'signals' : ['last_signal', 'prediction', 'last_error'],
    'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
                       'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
    'target' : 'error',
    'feature_history' : 2,
    'feature_lookahead' : 2,
    'sparsification' : {
        'gru_a' : {
            'start' : 10000,
            'stop' : 30000,
            'interval' : 100,
            'exponent' : 3,
            'params' : {
                'W_hr' : (0.05, [4, 8], True),
                'W_hz' : (0.05, [4, 8], True),
                'W_hn' : (0.2, [4, 8], True)
            },
        },
        'gru_b' : {
            'start' : 10000,
            'stop' : 30000,
            'interval' : 100,
            'exponent' : 3,
            'params' : {
                'W_ir' : (0.5, [4, 8], False),
                'W_iz' : (0.5, [4, 8], False),
                'W_in' : (0.5, [4, 8], False)
            },
        }
    },
    'add_reference_phase' : False,
    'reference_phase_dim' : 0
}

# multi rate
# per-GRU subconditioning configuration for the multi-rate variant
subconditioning = {
    'subconditioning_a' : {
        'number_of_subsamples' : 2,
        'method' : 'modulative',
        'signals' : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size' : 64,
        'kwargs' : dict()
    },
    'subconditioning_b' : {
        'number_of_subsamples' : 2,
        'method' : 'modulative',
        'signals' : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size' : 64,
        'kwargs' : dict()
    }
}

# multi-rate config = baseline config + subconditioning section
multi_rate_lpcnet_config = lpcnet_config.copy()
multi_rate_lpcnet_config['subconditioning'] = subconditioning

# default optimizer / schedule settings shared by both setups
training_default = {
    'batch_size' : 256,
    'epochs' : 20,
    'lr' : 1e-3,
    'lr_decay_factor' : 2.5e-5,
    'adam_betas' : [0.9, 0.99],
    'frames_per_sample' : 15
}

lpcnet_setup = {
    'dataset' : '/local/datasets/lpcnet_training',
    'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
    'training' : training_default
}

# deepcopy so the multi-rate setup can diverge without mutating the baseline
multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'

# registry of all available setups, keyed by model name
setup_dict = {
    'lpcnet' : lpcnet_setup,
    'multi_rate' : multi_rate_lpcnet_setup
}

View File

@@ -0,0 +1,58 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import torch
def ulaw2lin(u):
    """Map mu-law code values (offset by 128, range [0, 255]) back to linear PCM."""
    scale_1 = 32768.0 / 255.0
    centered = u - 128
    sign = torch.sign(centered)
    magnitude = torch.abs(centered)
    # invert the logarithmic companding curve
    return sign * scale_1 * (torch.exp(magnitude / 128. * m.log(256)) - 1)
def lin2ulawq(x):
    """Quantized mu-law companding: linear PCM -> integer code in [0, 255]."""
    scale = 255.0 / 32768.0
    sign = torch.sign(x)
    magnitude = torch.abs(x)
    compressed = sign * (128 * torch.log(1 + scale * magnitude) / m.log(256))
    # round to the nearest code and clamp to the valid 8-bit range
    return torch.clip(128 + torch.round(compressed), 0, 255)
def lin2ulaw(x):
    """Continuous (unquantized) mu-law companding: linear PCM -> [0, 255].

    Same curve as lin2ulawq but without rounding to integer codes.
    """
    scale = 255.0 / 32768.0
    s = torch.sign(x)
    x = torch.abs(x)
    # bug fix: the original called torch.log(256), which raises a TypeError
    # (torch.log expects a tensor); use math.log for the scalar denominator,
    # matching lin2ulawq above
    u = s * (128 * torch.log(1 + scale * x) / m.log(256))
    u = torch.clip(128 + u, 0, 255)
    return u

View File

@@ -0,0 +1,43 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import wave
def wavwrite16(filename, x, fs):
    """ writes x as int16 to file with name filename

    If x.dtype is int16 x is written as is. Otherwise,
    it is scaled by 2**15 - 1 and converted to int16.
    """
    samples = x if x.dtype == 'int16' else ((2**15 - 1) * x).astype('int16')
    with wave.open(filename, 'wb') as wav_file:
        # mono, 2 bytes per sample, no compression
        wav_file.setparams((1, 2, fs, len(samples), 'NONE', ""))
        wav_file.writeframes(samples.tobytes())

View File

@@ -0,0 +1,18 @@
## Neural Pitch Estimation
- Dataset Installation
1. Download and unzip PTDB Dataset:
wget https://www2.spsc.tugraz.at/databases/PTDB-TUG/SPEECH_DATA_ZIPPED.zip
unzip SPEECH_DATA_ZIPPED.zip
2. Inside "SPEECH DATA" above, run ptdb_process.sh to combine male/female
3. To Download and combine demand, simply run download_demand.sh
- LPCNet preparation
1. To extract xcorr, add lpcnet_extractor.c, add the relevant functions to lpcnet_enc.c, register the new header/source files in Makefile.am, and compile to generate the ./lpcnet_xcorr_extractor object
- Dataset Augmentation and training (check out arguments to each of the following)
1. Run data_augmentation.py
2. Run training.py using augmented data
3. Run experiments.py

View File

@@ -0,0 +1,149 @@
"""
Perform Data Augmentation (Gain, Additive Noise, Random Filtering) on Input TTS Data
1. Read in chunks and compute clean pitch first
2. Then add in augmentation (Noise/Level/Response)
- Adds filtered noise from the "Demand" dataset, https://zenodo.org/record/1227121#.XRKKxYhKiUk
- When using the Demand Dataset, consider each channel as a possible noise input, and keep the first 4 minutes of noise for training
3. Use this "augmented" audio for feature computation, and compute pitch using CREPE on the clean input
Notes: To ensure consistency with the discovered CREPE offset, we do the following
- We pad the input audio to the zero-centered CREPE estimator with 80 zeros
- We pad the input audio to our feature computation with 160 zeros to center them
"""
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('data', type=str, help='input raw audio data')
parser.add_argument('output', type=str, help='output directory')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
# NOTE(review): "Datset" in the help text below is a typo for "Dataset"
parser.add_argument('noise_dataset', type=str, help='Location of the Demand Datset')
parser.add_argument('--flag_xcorr', type=bool, help='Flag to additionally dump xcorr features',choices=[True,False],default = False,required = False)
parser.add_argument('--fraction_input_use', type=float, help='Fraction of input data to consider',default = 0.3,required = False)
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
parser.add_argument('--choice_augment', type=str, help='Choice of noise augmentation, either use additive synthetic noise or add noise from the demand dataset',choices = ['demand','synthetic'],default = "demand",required = False)
parser.add_argument('--fraction_clean', type=float, help='Fraction of data to keep clean (that is not augment with anything)',default = 0.2,required = False)
parser.add_argument('--chunk_size', type=int, help='Number of samples to augment with for each iteration',default = 80000,required = False)
parser.add_argument('--N', type=int, help='STFT window size',default = 320,required = False)
parser.add_argument('--H', type=int, help='STFT Hop size',default = 160,required = False)
parser.add_argument('--freq_keep', type=int, help='Number of Frequencies to keep',default = 30,required = False)

args = parser.parse_args()

import os
# restrict CUDA device visibility before any GPU-aware library is imported
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)

from utils import stft, random_filter
import numpy as np
import tqdm
import crepe
import random
import glob
import subprocess

# memory-map the raw int16 audio and keep only the requested fraction
data_full = np.memmap(args.data, dtype=np.int16,mode = 'r')
data = data_full[:(int)(args.fraction_input_use*data_full.shape[0])]

# list_features = []
list_cents = []
list_confidences = []

N = args.N
H = args.H
freq_keep = args.freq_keep

# Minimum/Maximum periods, decided by LPCNet
min_period = 32
max_period = 256
f_ref = 16000/max_period

chunk_size = args.chunk_size
num_frames_chunk = chunk_size//H
# indices of the low-frequency bins to keep from each of the 3 STFT-derived planes
list_indices_keep = np.concatenate([np.arange(freq_keep), (N//2 + 1) + np.arange(freq_keep), 2*(N//2 + 1) + np.arange(freq_keep)])

# pre-sized memory-mapped outputs for instantaneous-frequency and xcorr features
output_IF = np.memmap(args.output + '_iffeat.f32', dtype=np.float32, shape=(((data.shape[0]//chunk_size - 1)//1)*num_frames_chunk,list_indices_keep.shape[0]), mode='w+')
if args.flag_xcorr:
    output_xcorr = np.memmap(args.output + '_xcorr.f32', dtype=np.float32, shape=(((data.shape[0]//chunk_size - 1)//1)*num_frames_chunk,257), mode='w+')

fraction_clean = args.fraction_clean
noise_dataset = args.noise_dataset
for i in tqdm.trange((data.shape[0]//chunk_size - 1)//1):
chunk = data[i*chunk_size:(i + 1)*chunk_size]/(2**15 - 1)
# Clean Pitch/Confidence Estimate
# Padding input to CREPE by 80 samples to ensure it aligns
_, pitch, confidence, _ = crepe.predict(np.concatenate([np.zeros(80),chunk]), 16000, center=True, viterbi=True,verbose=0)
cent = 1200*np.log2(np.divide(pitch, f_ref, out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)
# Filter out of range pitches/confidences
confidence[pitch < 16000/max_period] = 0
confidence[pitch > 16000/min_period] = 0
# Keep fraction of data clean, augment only 1 minus the fraction
if (np.random.rand() > fraction_clean):
# Response, generate controlled/random 2nd order IIR filter and filter chunk
chunk = random_filter(chunk)
# Level/Gain response {scale by random gain between 1.0e-3 and 10}
# Generate random gain in dB and then convert to scale
g_dB = np.random.uniform(low = -60, high = 20, size = 1)
# g_dB = 0
g = 10**(g_dB/20)
# Noise Addition {Add random SNR 2nd order randomly colored noise}
# Generate noise SNR value and add corresponding noise
snr_dB = np.random.uniform(low = -20, high = 30, size = 1)
if args.choice_augment == 'synthetic':
n = np.random.randn(chunk_size)
else:
list_noisefiles = noise_dataset + '*.wav'
noise_file = random.choice(glob.glob(list_noisefiles))
n = np.memmap(noise_file, dtype=np.int16,mode = 'r')/(2**15 - 1)
rand_range = np.random.randint(low = 0, high = (n.shape[0] - 16000*60 - chunk.shape[0])) # 16000 is subtracted because we will use the last 1 minutes of noise for testing
n = n[rand_range:rand_range + chunk.shape[0]]
# Randomly filter the sampled noise as well
n = random_filter(n)
# generate random prime number between 0,500 and make those samples of noise 0 (to prevent GRU from picking up temporal patterns)
Nprime = random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541])
n[chunk_size - Nprime:] = np.zeros(Nprime)
snr_multiplier = np.sqrt((np.sum(np.abs(chunk)**2)/np.sum(np.abs(n)**2))*10**(-snr_dB/10))
chunk = g*(chunk + snr_multiplier*n)
# Zero pad input audio by 160 to center the frames
spec = stft(x = np.concatenate([np.zeros(160),chunk]), w = 'boxcar', N = N, H = H).T
phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
feature = feature[:,list_indices_keep]
if args.flag_xcorr:
# Dump noisy audio into temp file
data_temp = np.memmap('./temp_augment.raw', dtype=np.int16, shape=(chunk.shape[0]), mode='w+')
# data_temp[:chunk.shape[0]] = (chunk/(np.max(np.abs(chunk)))*(2**15 - 1)).astype(np.int16)
data_temp[:chunk.shape[0]] = ((chunk)*(2**15 - 1)).astype(np.int16)
subprocess.run([args.path_lpcnet_extractor, './temp_augment.raw', './temp_augment_xcorr.f32'])
feature_xcorr = np.flip(np.fromfile('./temp_augment_xcorr.f32', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)
os.remove('./temp_augment.raw')
os.remove('./temp_augment_xcorr.f32')
num_frames = min(cent.shape[0],feature.shape[0],feature_xcorr.shape[0],num_frames_chunk)
feature = feature[:num_frames,:]
cent = cent[:num_frames]
confidence = confidence[:num_frames]
feature_xcorr = feature_xcorr[:num_frames]
output_IF[i*num_frames_chunk:(i + 1)*num_frames_chunk,:] = feature
output_xcorr[i*num_frames_chunk:(i + 1)*num_frames_chunk,:] = feature_xcorr
list_cents.append(cent)
list_confidences.append(confidence)
list_cents = np.hstack(list_cents)
list_confidences = np.hstack(list_confidences)
np.save(args.output + '_pitches',np.vstack([list_cents,list_confidences]))

View File

@@ -0,0 +1,43 @@
# Download the 16 kHz DEMAND noise recordings from Zenodo, unpack them, and
# flatten every per-environment channel into ./combined_demand_channels/
# with names like "DKITCHEN_16k+ch01.wav".
wget https://zenodo.org/record/1227121/files/DKITCHEN_16k.zip
wget https://zenodo.org/record/1227121/files/DLIVING_16k.zip
wget https://zenodo.org/record/1227121/files/DWASHING_16k.zip
wget https://zenodo.org/record/1227121/files/NFIELD_16k.zip
wget https://zenodo.org/record/1227121/files/NPARK_16k.zip
wget https://zenodo.org/record/1227121/files/NRIVER_16k.zip
wget https://zenodo.org/record/1227121/files/OHALLWAY_16k.zip
wget https://zenodo.org/record/1227121/files/OMEETING_16k.zip
wget https://zenodo.org/record/1227121/files/OOFFICE_16k.zip
wget https://zenodo.org/record/1227121/files/PCAFETER_16k.zip
wget https://zenodo.org/record/1227121/files/PRESTO_16k.zip
wget https://zenodo.org/record/1227121/files/PSTATION_16k.zip
wget https://zenodo.org/record/1227121/files/TMETRO_16k.zip
wget https://zenodo.org/record/1227121/files/TCAR_16k.zip
wget https://zenodo.org/record/1227121/files/TBUS_16k.zip
wget https://zenodo.org/record/1227121/files/STRAFFIC_16k.zip
wget https://zenodo.org/record/1227121/files/SPSQUARE_16k.zip
unzip '*.zip'
mkdir -p ./combined_demand_channels/
for file in */*.wav; do
    parentdir="$(dirname "$file")"
    echo "$parentdir"
    fname="$(basename "$file")"
    # BUG FIX: expansions are quoted so paths containing spaces or glob
    # characters are copied correctly.
    cp "$file" "./combined_demand_channels/${parentdir}+${fname}"
done

View File

@@ -0,0 +1,349 @@
"""
Evaluation script to compute the Raw Pitch Accuracy
Procedure:
- Look at all voiced frames in file
- Compute number of pitches in those frames that lie within a 50 cent threshold
RPA = (Total number of pitches within threshold summed across all files)/(Total number of voiced frames summed accross all files)
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from prettytable import PrettyTable
import numpy as np
import glob
import random
import tqdm
import torch
import librosa
import json
from utils import stft, random_filter, feature_xform
import subprocess
import crepe
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def rca(reference, input, voicing, thresh=25):
    """Raw Chroma/Pitch Accuracy count: number of voiced frames whose
    absolute pitch deviation (in cents) is below ``thresh``.

    ``reference`` and ``input`` are per-frame pitch tracks in cents;
    ``voicing`` is nonzero on voiced frames. Returns an int count.
    """
    voiced_mask = voicing != 0
    deviation = np.abs(reference - input)[voiced_mask]
    return int(np.count_nonzero(deviation < thresh))
def sweep_rca(reference, input, voicing, thresh=25, ind_arr=np.arange(-10, 10)):
    """Best RCA over circular shifts of ``input`` by every offset in
    ``ind_arr`` (compensates for small frame alignment mismatches)."""
    scores = [rca(reference, np.roll(input, shift), voicing, thresh)
              for shift in ind_arr]
    return np.max(np.array(scores))
def rpa(model, device='cpu', data_format='if'):
    """Print Raw Pitch Accuracy of `model` over (up to 1000) PTDB files.

    data_format selects the model input: 'if' (IF features), 'xcorr'
    (LPCNet cross-correlation), anything else concatenates both.
    Results are printed as a PrettyTable (all / male / female); returns None.
    NOTE(review): dataset locations are hard-coded absolute paths below.
    """
    list_files = glob.glob('/home/ubuntu/Code/Datasets/SPEECH DATA/combined_mic_16k_raw/*.raw')
    dir_f0 = '/home/ubuntu/Code/Datasets/SPEECH DATA/combine_f0_ptdb/'
    # random_shuffle = list(np.random.permutation(len(list_files)))
    random.shuffle(list_files)
    list_files = list_files[:1000]
    # Voiced-frame counters: all speakers / male / female.
    C_all = 0
    C_all_m = 0
    C_all_f = 0
    list_rca_model_all = []
    list_rca_male_all = []
    list_rca_female_all = []
    thresh = 50
    N = 320
    H = 160
    freq_keep = 30
    for idx in tqdm.trange(len(list_files)):
        audio_file = list_files[idx]
        file_name = os.path.basename(list_files[idx])[:-4]
        audio = np.memmap(list_files[idx], dtype=np.int16)/(2**15 - 1)
        # Skip the file header region so audio aligns with the reference f0
        # track (assumption — confirm offset against dataset preprocessing).
        offset = 432
        audio = audio[offset:]
        rmse = np.squeeze(librosa.feature.rms(y = audio,frame_length = 320,hop_length = 160))
        # IF features: log-magnitude STFT plus unit-normalized phase advance,
        # keeping only the lowest `freq_keep` bins of each component.
        spec = stft(x = np.concatenate([np.zeros(160),audio]), w = 'boxcar', N = N, H = H).T
        phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
        phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
        idx_save = np.concatenate([np.arange(freq_keep),(N//2 + 1) + np.arange(freq_keep),2*(N//2 + 1) + np.arange(freq_keep)])
        feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
        feature_if = feature[:,idx_save]
        # xcorr features come from the external LPCNet extractor via temp files.
        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
        data_temp[:audio.shape[0]] = (audio/(np.max(np.abs(audio)))*(2**15 - 1)).astype(np.int16)
        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
        feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
        ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
        feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)
        # feature_xcorr = feature_xform(feature_xcorr)
        os.remove('./temp.raw')
        os.remove('./temp_xcorr.f32')
        if data_format == 'if':
            feature = feature_if
        elif data_format == 'xcorr':
            feature = feature_xcorr
        else:
            indmin = min(feature_if.shape[0],feature_xcorr.shape[0])
            feature = np.concatenate([feature_xcorr[:indmin,:],feature_if[:indmin,:]],-1)
        # Reference pitch/voicing; low-energy frames are dropped from voicing.
        pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
        pitch = np.loadtxt(pitch_file_name)[:,0]
        voicing = np.loadtxt(pitch_file_name)[:,1]
        indmin = min(voicing.shape[0],rmse.shape[0],pitch.shape[0])
        pitch = pitch[:indmin]
        voicing = voicing[:indmin]
        rmse = rmse[:indmin]
        voicing = voicing*(rmse > 0.05*np.max(rmse))
        if "mic_F" in audio_file:
            # Female-speaker reference pitches below 125 Hz are treated as
            # tracking errors and excluded from scoring.
            idx_correct = np.where(pitch < 125)
            voicing[idx_correct] = 0
        # Reference pitch in 20-cent-quantizable cents relative to 62.5 Hz.
        cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int')
        # Model logits over 20-cent classes -> predicted cents.
        model_cents = model(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
        model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
        num_frames = min(cent.shape[0],model_cents.shape[0])
        pitch = pitch[:num_frames]
        cent = cent[:num_frames]
        voicing = voicing[:num_frames]
        model_cents = model_cents[:num_frames]
        voicing_all = np.copy(voicing)
        # Forcefully make regions where pitch is <65 or greater than 500 unvoiced for relevant accurate pitch comparisons for our model
        force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 500)==True)
        voicing_all[force_out_of_pitch] = 0
        C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
        list_rca_model_all.append(rca(cent,model_cents,voicing_all,thresh))
        if "mic_M" in audio_file:
            list_rca_male_all.append(rca(cent,model_cents,voicing_all,thresh))
            C_all_m = C_all_m + np.where(voicing_all != 0)[0].shape[0]
        else:
            list_rca_female_all.append(rca(cent,model_cents,voicing_all,thresh))
            C_all_f = C_all_f + np.where(voicing_all != 0)[0].shape[0]
    list_rca_model_all = np.array(list_rca_model_all)
    list_rca_male_all = np.array(list_rca_male_all)
    list_rca_female_all = np.array(list_rca_female_all)
    x = PrettyTable()
    x.field_names = ["Experiment", "Mean RPA"]
    x.add_row(["Both all pitches", np.sum(list_rca_model_all)/C_all])
    x.add_row(["Male all pitches", np.sum(list_rca_male_all)/C_all_m])
    x.add_row(["Female all pitches", np.sum(list_rca_female_all)/C_all_f])
    print(x)
    return None
def cycle_eval(checkpoint_list, noise_type='synthetic', noise_dataset=None, list_snr=[-20, -15, -10, -5, 0, 5, 10, 15, 20], ptdb_dataset_path=None, fraction=0.1, thresh=50):
    """
    Cycle through SNR evaluation for list of checkpoints.

    Each entry in `checkpoint_list` is either a .pth checkpoint path or one of
    the baseline names 'crepe' / 'lpcnet'. Returns a dict mapping each entry to
    {'list_SNR': [accuracy per SNR], 'inf': clean accuracy}.
    """
    list_files = glob.glob(ptdb_dataset_path + 'combined_mic_16k/*.raw')
    dir_f0 = ptdb_dataset_path + 'combined_reference_f0/'
    random.shuffle(list_files)
    list_files = list_files[:(int)(fraction * len(list_files))]

    dict_models = {}
    # BUG FIX: the original did list_snr.append(np.inf), mutating the shared
    # mutable default (and any caller-supplied list) on every call, so the SNR
    # sweep grew an extra inf entry per invocation. Work on a copy instead.
    list_snr = list(list_snr) + [np.inf]

    for f in checkpoint_list:
        # BUG FIX: `fname` was only assigned in the crepe/lpcnet branch, so the
        # final dict_models[fname] store raised NameError for real checkpoints.
        fname = f
        if (f != 'crepe') and (f != 'lpcnet'):
            # --- Neural model checkpoint ---
            checkpoint = torch.load(f, map_location='cpu')
            dict_params = checkpoint['config']
            # BUG FIX: removed the unused `from models import large_* as model`
            # aliases; those names do not exist in models.py and raised
            # ImportError before the evaluation even started.
            if dict_params['data_format'] == 'if':
                pitch_nn = PitchDNNIF(dict_params['freq_keep'] * 3, dict_params['gru_dim'], dict_params['output_dim'])
            elif dict_params['data_format'] == 'xcorr':
                pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
            else:
                pitch_nn = PitchDNN(dict_params['freq_keep'] * 3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
            pitch_nn.load_state_dict(checkpoint['state_dict'])
            N = dict_params['window_size']
            H = dict_params['hop_factor']
            freq_keep = dict_params['freq_keep']
            list_mean = []
            for snr_dB in list_snr:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    audio = np.memmap(list_files[idx], dtype=np.int16) / (2 ** 15 - 1)
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y=audio, frame_length=N, hop_length=H))
                    # Additive noise at the requested SNR (inf = clean pass).
                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16, mode='r') / (2 ** 15 - 1)
                        # Last 1 minute of noise used for testing
                        rand_range = np.random.randint(low=0, high=(16000 * 60 * 5 - audio.shape[0]))
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                    n = random_filter(n)
                    snr_multiplier = np.sqrt((np.sum(np.abs(audio) ** 2) / np.sum(np.abs(n) ** 2)) * 10 ** (-snr_dB / 10))
                    audio = audio + snr_multiplier * n
                    # IF features (log-magnitude + normalized phase advance).
                    spec = stft(x=np.concatenate([np.zeros(160), audio]), w='boxcar', N=N, H=H).T
                    phase_diff = spec * np.conj(np.roll(spec, 1, axis=-1))
                    phase_diff = phase_diff / (np.abs(phase_diff) + 1.0e-8)
                    idx_save = np.concatenate([np.arange(freq_keep), (N // 2 + 1) + np.arange(freq_keep), 2 * (N // 2 + 1) + np.arange(freq_keep)])
                    feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8), np.real(phase_diff), np.imag(phase_diff)], axis=0).T
                    feature_if = feature[:, idx_save]
                    # xcorr features from the external LPCNet extractor.
                    data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                    data_temp[:audio.shape[0]] = ((audio) * (2 ** 15 - 1)).astype(np.int16)
                    subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32'])
                    feature_xcorr = np.flip(np.fromfile('./temp_xcorr.f32', dtype='float32').reshape((-1, 256), order='C'), axis=1)
                    ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]), -1)
                    feature_xcorr = np.concatenate([ones_zero_lag, feature_xcorr], axis=-1)
                    os.remove('./temp.raw')
                    os.remove('./temp_xcorr.f32')
                    if dict_params['data_format'] == 'if':
                        feature = feature_if
                    elif dict_params['data_format'] == 'xcorr':
                        feature = feature_xcorr
                    else:
                        indmin = min(feature_if.shape[0], feature_xcorr.shape[0])
                        feature = np.concatenate([feature_xcorr[:indmin, :], feature_if[:indmin, :]], -1)
                    # Reference pitch/voicing; low-energy frames removed.
                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch = np.loadtxt(pitch_file_name)[:, 0]
                    voicing = np.loadtxt(pitch_file_name)[:, 1]
                    indmin = min(voicing.shape[0], rmse.shape[0], pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing * (rmse > 0.05 * np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0
                    cent = np.rint(1200 * np.log2(np.divide(pitch, (16000 / 256), out=np.zeros_like(pitch), where=pitch != 0) + 1.0e-8)).astype('int')
                    # Model logits over 20-cent classes -> predicted cents.
                    model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature, 0))).float().to(device))
                    model_cents = 20 * model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
                    num_frames = min(cent.shape[0], model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]
                    voicing_all = np.copy(voicing)
                    # Restrict scoring to the 65-500 Hz range relevant to the model.
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65, pitch > 500) == True)
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
                    C_correct = C_correct + rca(cent, model_cents, voicing_all, thresh)
                list_mean.append(C_correct / C_all)
        else:
            # --- Baselines: CREPE or the LPCNet pitch tracker ---
            list_mean = []
            for snr_dB in list_snr:
                C_all = 0
                C_correct = 0
                for idx in tqdm.trange(len(list_files)):
                    audio_file = list_files[idx]
                    audio = np.memmap(list_files[idx], dtype=np.int16) / (2 ** 15 - 1)
                    offset = 432
                    audio = audio[offset:]
                    rmse = np.squeeze(librosa.feature.rms(y=audio, frame_length=320, hop_length=160))
                    if noise_type != 'synthetic':
                        list_noisefiles = noise_dataset + '*.wav'
                        noise_file = random.choice(glob.glob(list_noisefiles))
                        n = np.memmap(noise_file, dtype=np.int16, mode='r') / (2 ** 15 - 1)
                        # Last 1 minute of noise used for testing
                        rand_range = np.random.randint(low=0, high=(16000 * 60 * 5 - audio.shape[0]))
                        n = n[rand_range:rand_range + audio.shape[0]]
                    else:
                        n = np.random.randn(audio.shape[0])
                    n = random_filter(n)
                    snr_multiplier = np.sqrt((np.sum(np.abs(audio) ** 2) / np.sum(np.abs(n) ** 2)) * 10 ** (-snr_dB / 10))
                    audio = audio + snr_multiplier * n
                    if (f == 'crepe'):
                        # 80-sample pad keeps CREPE's frames aligned with the f0 track.
                        _, model_frequency, _, _ = crepe.predict(np.concatenate([np.zeros(80), audio]), 16000, viterbi=True, center=True, verbose=0)
                        model_cents = 1200 * np.log2(model_frequency / (16000 / 256) + 1.0e-8)
                    else:
                        data_temp = np.memmap('./temp.raw', dtype=np.int16, shape=(audio.shape[0]), mode='w+')
                        data_temp[:audio.shape[0]] = ((audio) * (2 ** 15 - 1)).astype(np.int16)
                        subprocess.run(["../../../lpcnet_xcorr_extractor", './temp.raw', './temp_xcorr.f32', './temp_period.f32'])
                        feature_xcorr = np.fromfile('./temp_period.f32', dtype='float32')
                        # LPCNet dumps periods; convert to cents relative to 62.5 Hz.
                        model_cents = 1200 * np.log2((256 / feature_xcorr + 1.0e-8) + 1.0e-8)
                        os.remove('./temp.raw')
                        os.remove('./temp_xcorr.f32')
                        os.remove('./temp_period.f32')
                    pitch_file_name = dir_f0 + "ref" + os.path.basename(list_files[idx])[3:-4] + ".f0"
                    pitch = np.loadtxt(pitch_file_name)[:, 0]
                    voicing = np.loadtxt(pitch_file_name)[:, 1]
                    indmin = min(voicing.shape[0], rmse.shape[0], pitch.shape[0])
                    pitch = pitch[:indmin]
                    voicing = voicing[:indmin]
                    rmse = rmse[:indmin]
                    voicing = voicing * (rmse > 0.05 * np.max(rmse))
                    if "mic_F" in audio_file:
                        idx_correct = np.where(pitch < 125)
                        voicing[idx_correct] = 0
                    cent = np.rint(1200 * np.log2(np.divide(pitch, (16000 / 256), out=np.zeros_like(pitch), where=pitch != 0) + 1.0e-8)).astype('int')
                    num_frames = min(cent.shape[0], model_cents.shape[0])
                    pitch = pitch[:num_frames]
                    cent = cent[:num_frames]
                    voicing = voicing[:num_frames]
                    model_cents = model_cents[:num_frames]
                    voicing_all = np.copy(voicing)
                    # Restrict scoring to the 65-500 Hz range relevant to the model.
                    force_out_of_pitch = np.where(np.logical_or(pitch < 65, pitch > 500) == True)
                    voicing_all[force_out_of_pitch] = 0
                    C_all = C_all + np.where(voicing_all != 0)[0].shape[0]
                    C_correct = C_correct + rca(cent, model_cents, voicing_all, thresh)
                list_mean.append(C_correct / C_all)
        dict_models[fname] = {}
        dict_models[fname]['list_SNR'] = list_mean[:-1]
        dict_models[fname]['inf'] = list_mean[-1]
    return dict_models

View File

@@ -0,0 +1,38 @@
"""
Running the experiments;
1. RCA vs SNR for our models, CREPE, LPCNet
"""
# Parse CLI first; GPU selection must happen before importing torch (below).
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('ptdb_root', type=str, help='Root Directory for PTDB generated by running ptdb_process.sh ')
parser.add_argument('output', type=str, help='Output dump file name')
parser.add_argument('method', type=str, help='Output Directory to save experiment dumps', choices=['model', 'lpcnet', 'crepe'])
parser.add_argument('--noise_dataset', type=str, help='Location of the Demand Datset', default='./', required=False)
parser.add_argument('--noise_type', type=str, help='Type of additive noise', default='synthetic', choices=['synthetic', 'demand'], required=False)
parser.add_argument('--pth_file', type=str, help='.pth file to analyze', default='./', required=False)
parser.add_argument('--fraction_files_analyze', type=float, help='Fraction of PTDB dataset to test on', default=1, required=False)
parser.add_argument('--threshold_rca', type=float, help='Cent threshold when computing RCA', default=50, required=False)
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs', default=0, required=False)
args = parser.parse_args()

# Pin the GPU before `evaluation` (which imports torch/crepe) is loaded.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)

import json
from evaluation import cycle_eval

# 'model' evaluates the supplied checkpoint; 'crepe'/'lpcnet' run the baselines.
if args.method == 'model':
    dict_store = cycle_eval([args.pth_file], noise_type=args.noise_type, noise_dataset=args.noise_dataset, list_snr=[-20, -15, -10, -5, 0, 5, 10, 15, 20], ptdb_dataset_path=args.ptdb_root, fraction=args.fraction_files_analyze, thresh=args.threshold_rca)
else:
    dict_store = cycle_eval([args.method], noise_type=args.noise_type, noise_dataset=args.noise_dataset, list_snr=[-20, -15, -10, -5, 0, 5, 10, 15, 20], ptdb_dataset_path=args.ptdb_root, fraction=args.fraction_files_analyze, thresh=args.threshold_rca)
dict_store["method"] = args.method
if args.method == 'model':
    dict_store['pth'] = args.pth_file

# Persist the accuracy-vs-SNR results as JSON for later plotting.
with open(args.output, 'w') as fp:
    json.dump(dict_store, fp)

View File

@@ -0,0 +1,109 @@
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
# Make the local weight-exchange package importable relative to this script.
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

# CLI: checkpoint in, generated C sources out.
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
args = parser.parse_args()
import torch
import numpy as np
from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
def c_export(args, model):
    """Dump a trained PitchDNN's weights as C sources via CWriter.

    Dense layers and the GRU are quantized on export; the 2-D conv layers
    are exported unquantized. Writes pitchdnn_data.{c,h} to args.output_dir.
    """
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
    writer = CWriter(os.path.join(args.output_dir, "pitchdnn_data"), message=message, model_struct_name='PitchDNN')
    writer.header.write(
f"""
#include "opus_types.h"
"""
    )
    # (PyTorch submodule path, exported C symbol name) pairs.
    dense_layers = [
        ('if_upsample.0', "dense_if_upsampler_1"),
        ('if_upsample.2', "dense_if_upsampler_2"),
        ('downsample.0', "dense_downsampler"),
        ("upsample.0", "dense_final_upsampler")
    ]
    for name, export_name in dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=True, scale=None)
    conv_layers = [
        ('conv.1', "conv2d_1"),
        ('conv.4', "conv2d_2")
    ]
    for name, export_name in conv_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(writer, layer, name=export_name, verbose=True)
    gru_layers = [
        ("GRU", "gru_1"),
    ]
    # dump_torch_weights returns the unit count for recurrent layers; the C
    # side sizes its scratch buffers from the maximum.
    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
                         for name, export_name in gru_layers])
    writer.header.write(
f"""
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
"""
    )
    writer.close()


if __name__ == "__main__":
    os.makedirs(args.output_dir, exist_ok=True)
    # Rebuild the default-architecture model and load the trained weights.
    model = PitchDNN()
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    c_export(args, model)

View File

@@ -0,0 +1,178 @@
"""
Pitch Estimation Models and dataloaders
- Classification Based (Input features, output logits)
"""
import torch
import numpy as np
class PitchDNNIF(torch.nn.Module):
    """Pitch classifier over IF features: two dense layers, a GRU, and a
    dense output head; returns per-frame logits over `output_dim` pitch
    classes with shape (batch, output_dim, time).

    Attribute names (initial/hidden/gru/upsample) are part of the
    checkpoint state_dict layout and must not change.
    """

    def __init__(self, input_dim=88, gru_dim=64, output_dim=192):
        super().__init__()
        self.activation = torch.nn.Tanh()
        self.initial = torch.nn.Linear(input_dim, gru_dim)
        self.hidden = torch.nn.Linear(gru_dim, gru_dim)
        self.gru = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, batch_first=True)
        self.upsample = torch.nn.Linear(gru_dim, output_dim)

    def forward(self, x):
        # x: (batch, time, input_dim)
        first = self.activation(self.initial(x))
        second = self.activation(self.hidden(first))
        recurrent, _ = self.gru(second)
        logits = self.activation(self.upsample(recurrent))
        # -> (batch, output_dim, time) for CrossEntropyLoss / argmax(dim=1)
        return logits.permute(0, 2, 1)
class PitchDNNXcorr(torch.nn.Module):
    """Pitch classifier over xcorr features: a 2-D conv stack (shape
    preserving thanks to the asymmetric zero padding), a dense downsample,
    a GRU and a dense output head. Output: (batch, output_dim, time).

    Attribute names (conv/downsample/GRU/upsample) are part of the
    checkpoint state_dict layout and must not change.
    """

    def __init__(self, input_dim=90, gru_dim=64, output_dim=192):
        super().__init__()
        self.activation = torch.nn.Tanh()
        self.conv = torch.nn.Sequential(
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(1, 8, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(8, 8, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(8, 1, 3, bias=True),
            self.activation,
        )
        self.downsample = torch.nn.Sequential(
            torch.nn.Linear(input_dim, gru_dim),
            self.activation
        )
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        self.upsample = torch.nn.Sequential(
            torch.nn.Linear(gru_dim, output_dim),
            self.activation
        )

    def forward(self, x):
        # (batch, time, dim) -> single-channel image (batch, 1, dim, time)
        image = x.unsqueeze(-1).permute(0, 3, 2, 1)
        conv_out = self.conv(image).squeeze(1)      # (batch, dim, time)
        frames = conv_out.permute(0, 2, 1)          # (batch, time, dim)
        recurrent, _ = self.GRU(self.downsample(frames))
        return self.upsample(recurrent).permute(0, 2, 1)
class PitchDNN(torch.nn.Module):
    """
    Joint IF-xcorr
    1D CNN on IF, merge with xcorr, 2D CNN on merged + GRU

    Input is the concatenation [xcorr (224) | IF (input_IF_dim)] along the
    last axis; output is (batch, output_dim, time) logits.
    Attribute names are part of the checkpoint state_dict layout.
    """

    def __init__(self, input_IF_dim=88, input_xcorr_dim=224, gru_dim=64, output_dim=192):
        super().__init__()
        self.activation = torch.nn.Tanh()
        self.if_upsample = torch.nn.Sequential(
            torch.nn.Linear(input_IF_dim, 64),
            self.activation,
            torch.nn.Linear(64, 64),
            self.activation,
        )
        self.conv = torch.nn.Sequential(
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(1, 4, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(4, 1, 3, bias=True),
            self.activation,
        )
        self.downsample = torch.nn.Sequential(
            torch.nn.Linear(64 + input_xcorr_dim, gru_dim),
            self.activation
        )
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        self.upsample = torch.nn.Sequential(
            torch.nn.Linear(gru_dim, output_dim)
        )

    def forward(self, x):
        # Split the concatenated input back into its two feature streams.
        xcorr_stream = x[:, :, :224]
        if_stream = x[:, :, 224:]
        # xcorr goes through the shape-preserving 2-D conv stack.
        xcorr_stream = self.conv(xcorr_stream.unsqueeze(-1).permute(0, 3, 2, 1)).squeeze(1).permute(0, 2, 1)
        # IF goes through the dense embedding.
        if_stream = self.if_upsample(if_stream)
        merged = torch.cat([xcorr_stream, if_stream], axis=-1)
        recurrent, _ = self.GRU(self.downsample(merged))
        return self.upsample(recurrent).permute(0, 2, 1)
# Dataloaders
class Loader(torch.utils.data.Dataset):
    """Dataset of IF-feature windows with CREPE pitch targets.

    `features_if` is a raw float32 file with 3*dimension_if values per frame;
    `file_pitch` is a .npy array whose row 0 is pitch in cents and row 1 is
    CREPE confidence. Yields (features, cent class, confidence) windows of
    `context` frames.
    """

    def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100):
        feat_width = 3 * dimension_if
        self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1, feat_width)
        pitch_data = np.load(file_pitch)
        # Resolution of 20 cents, clipped to the 180-class range.
        self.cents = np.clip(np.rint(pitch_data[0, :] / 20), 0, 179)
        self.confidence = pitch_data[1, :]
        # Filter confidence for CREPE.
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context
        # Clip features and targets to a common length, then window.
        size_common = min(self.if_feat.shape[0], self.cents.shape[0])
        self.if_feat = self.if_feat[:size_common, :]
        self.cents = self.cents[:size_common]
        self.confidence = self.confidence[:size_common]
        frame_max = size_common // context
        self.if_feat = np.reshape(self.if_feat[:frame_max * context, :], (frame_max, context, feat_width))
        self.cents = np.reshape(self.cents[:frame_max * context], (frame_max, context))
        self.confidence = np.reshape(self.confidence[:frame_max * context], (frame_max, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        features = torch.from_numpy(self.if_feat[index, :, :])
        targets = torch.from_numpy(self.cents[index])
        weights = torch.from_numpy(self.confidence[index])
        return features, targets, weights
class PitchDNNDataloader(torch.utils.data.Dataset):
    """Dataset over dumped int8 features (224 xcorr + 88 IF per frame) with
    float32 (pitch Hz, confidence) ground truth.

    Pitch is mapped to 20-cent classes relative to 62.5 Hz; out-of-range
    pitches get zero confidence. `choice_data` selects which feature stream
    each item yields ('both', 'if', or 'xcorr'); int8 features are rescaled
    by 1/127 on access.
    """

    def __init__(self, features, file_pitch, confidence_threshold=0.4, context=100, choice_data='both'):
        self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1, 312)
        self.xcorr = self.feat[:, :224]
        self.if_feat = self.feat[:, 224:]
        ground_truth = np.memmap(file_pitch, mode='r', dtype=np.float32).reshape(-1, 2)
        # 60 classes/octave == 20-cent resolution relative to 62.5 Hz.
        self.cents = np.rint(60 * np.log2(ground_truth[:, 0] / 62.5))
        in_range = (self.cents >= 0).astype('float32') * (self.cents <= 180).astype('float32')
        self.cents = np.clip(self.cents, 0, 179)
        self.confidence = ground_truth[:, 1] * in_range
        # Filter confidence for CREPE.
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context
        self.choice_data = choice_data
        windows = self.if_feat.shape[0] // context
        self.if_feat = np.reshape(self.if_feat[:windows * context, :], (windows, context, 88))
        self.cents = np.reshape(self.cents[:windows * context], (windows, context))
        self.xcorr = np.reshape(self.xcorr[:windows * context, :], (windows, context, 224))
        self.confidence = np.reshape(self.confidence[:windows * context], (windows, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        targets = torch.from_numpy(self.cents[index])
        weights = torch.from_numpy(self.confidence[index])
        if_slice = torch.from_numpy((1. / 127) * self.if_feat[index, :, :])
        xcorr_slice = torch.from_numpy((1. / 127) * self.xcorr[index, :, :])
        if self.choice_data == 'both':
            return torch.cat([xcorr_slice, if_slice], dim=-1), targets, weights
        elif self.choice_data == 'if':
            return if_slice, targets, weights
        else:
            return xcorr_slice, targets, weights

View File

@@ -0,0 +1,179 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='Features generated from dump_data')
parser.add_argument('data', type=str, help='Data generated from dump_data (offset by 5ms)')
parser.add_argument('output', type=str, help='output .f32 feature file with replaced neural pitch')
parser.add_argument('checkpoint', type=str, help='model checkpoint file')
parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
parser.add_argument('--device', type=str, help='compute device', default=None, required=False)
# BUG FIX: argparse `type=bool` converts every non-empty string to True; use a
# real string parser so "--replace_xcorr False" behaves as expected.
parser.add_argument('--replace_xcorr', type=lambda s: str(s).lower() in ('true', '1'), default=False, help='Replace LPCNet xcorr with updated one')
args = parser.parse_args()

import os
from utils import stft, random_filter
import subprocess
import numpy as np
import json
import torch
import tqdm
from models import PitchDNNIF, PitchDNNXcorr, PitchDNN

# Default to CUDA when available; an explicit --device overrides it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BUG FIX: the original tested `device is not None`, which is always true, and
# then called torch.device(args.device) unconditionally — raising when
# --device was omitted (args.device defaults to None). Test the argument.
if args.device is not None:
    device = torch.device(args.device)

# Loading the appropriate model, selected by the checkpoint's data format.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
dict_params = checkpoint['config']
if dict_params['data_format'] == 'if':
    pitch_nn = PitchDNNIF(dict_params['freq_keep'] * 3, dict_params['gru_dim'], dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
    pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
else:
    pitch_nn = PitchDNN(dict_params['freq_keep'] * 3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
pitch_nn.load_state_dict(checkpoint['state_dict'])
pitch_nn = pitch_nn.to(device)

# Feature-extraction parameters stored alongside the weights.
N = dict_params['window_size']
H = dict_params['hop_factor']
freq_keep = dict_params['freq_keep']

os.environ["OMP_NUM_THREADS"] = "16"
def run_lpc(signal, lpcs, frame_length=160):
    """Run per-frame LPC analysis filtering over a signal.

    Args:
        signal: 1-D array of length num_frames*frame_length + lpc_order,
            i.e. the frames to filter preceded by lpc_order history samples.
        lpcs: (num_frames, lpc_order) array of LPC coefficients per frame.
        frame_length: samples per frame.

    Returns:
        (prediction, error): the negated LPC prediction for each frame
        (concatenated, length num_frames*frame_length) and the residual
        signal[lpc_order:] - prediction.
    """
    num_frames, lpc_order = lpcs.shape
    frame_predictions = []
    for frame_idx in range(num_frames):
        start = frame_idx * frame_length
        stop = start + frame_length + lpc_order - 1
        # 'valid' convolution over frame + lpc_order-1 history samples
        # yields exactly frame_length prediction samples.
        frame_predictions.append(-np.convolve(signal[start:stop], lpcs[frame_idx], mode='valid'))
    prediction = np.concatenate(frame_predictions)
    error = signal[lpc_order:] - prediction
    return prediction, error
if __name__ == "__main__":
    args = parser.parse_args()
    # Memory-mapped I/O: 36-dim float32 feature frames and 16-bit stereo-ish
    # data rows (column 1 is used as the signal below).
    features = np.memmap(args.features, dtype=np.float32,mode = 'r').reshape((-1, 36))
    data = np.memmap(args.data, dtype=np.int16,mode = 'r').reshape((-1, 2))
    num_frames = features.shape[0]
    feature_dim = features.shape[1]
    assert feature_dim == 36
    # Output starts as a copy of the input features; pitch/xcorr columns are
    # overwritten chunk by chunk below.
    output = np.memmap(args.output, dtype=np.float32, shape=(num_frames, feature_dim), mode='w+')
    output[:, :36] = features
    # lpc coefficients and signal
    lpcs = features[:, 20:36]
    sig = data[:, 1]
    # parameters
    # constants
    pitch_min = 32
    pitch_max = 256
    lpc_order = 16
    fs = 16000
    frame_length = 160
    overlap_frames = 100
    chunk_size = 10000
    history_length = frame_length * overlap_frames
    history = np.zeros(history_length, dtype=np.int16)
    # Feature-column indices in the 36-dim feature vector.
    # NOTE(review): conf_position (36) is out of range for a 36-column array
    # and is never used below.
    pitch_position=18
    xcorr_position=19
    conf_position=36
    num_frames = len(sig) // 160 - 1
    frame_start = 0
    frame_stop = min(frame_start + chunk_size, num_frames)
    signal_start = 0
    signal_stop = frame_stop * frame_length
    # NOTE(review): floor division here can drop a final partial chunk;
    # presumably acceptable for long inputs — confirm.
    niters = (num_frames - 1)//chunk_size
    for i in tqdm.trange(niters):
        if (frame_start > num_frames - 1):
            break
        # `chunk_la` carries 80 samples of look-ahead relative to `chunk`.
        chunk = np.concatenate((history, sig[signal_start:signal_stop]))
        chunk_la = np.concatenate((history, sig[signal_start:signal_stop + 80]))
        # Feature computation
        spec = stft(x = np.concatenate([np.zeros(80),chunk_la/(2**15 - 1)]), w = 'boxcar', N = N, H = H).T
        # Instantaneous-frequency features: normalized phase difference
        # between adjacent analysis frames.
        phase_diff = spec*np.conj(np.roll(spec,1,axis = -1))
        phase_diff = phase_diff/(np.abs(phase_diff) + 1.0e-8)
        idx_save = np.concatenate([np.arange(freq_keep),(N//2 + 1) + np.arange(freq_keep),2*(N//2 + 1) + np.arange(freq_keep)])
        feature = np.concatenate([np.log(np.abs(spec) + 1.0e-8),np.real(phase_diff),np.imag(phase_diff)],axis = 0).T
        feature_if = feature[:,idx_save]
        # Run the external LPCNet extractor over the chunk via temp files to
        # obtain xcorr features; files are deleted right after being read.
        data_temp = np.memmap('./temp_featcompute_' + dict_params['data_format'] + '_.raw', dtype=np.int16, shape=(chunk.shape[0]), mode='w+')
        data_temp[:chunk.shape[0]] = chunk_la[80:].astype(np.int16)
        subprocess.run([args.path_lpcnet_extractor, './temp_featcompute_' + dict_params['data_format'] + '_.raw', './temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw'])
        feature_xcorr = np.flip(np.fromfile('./temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw', dtype='float32').reshape((-1,256),order = 'C'),axis = 1)
        ones_zero_lag = np.expand_dims(np.ones(feature_xcorr.shape[0]),-1)
        feature_xcorr = np.concatenate([ones_zero_lag,feature_xcorr],axis = -1)
        os.remove('./temp_featcompute_' + dict_params['data_format'] + '_.raw')
        os.remove('./temp_featcompute_xcorr_' + dict_params['data_format'] + '_.raw')
        # Select the feature set matching the checkpoint's data format.
        if dict_params['data_format'] == 'if':
            feature = feature_if
        elif dict_params['data_format'] == 'xcorr':
            feature = feature_xcorr
        else:
            indmin = min(feature_if.shape[0],feature_xcorr.shape[0])
            feature = np.concatenate([feature_xcorr[:indmin,:],feature_if[:indmin,:]],-1)
        # Compute pitch with my model
        model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device))
        # argmax over 20-cent bins -> cents above the 62.5 Hz reference.
        model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy()
        frequency = 62.5*2**(model_cents/1200)
        # Drop the overlap region; keep only the frames of this chunk.
        frequency = frequency[overlap_frames : overlap_frames + frame_stop - frame_start]
        # convert frequencies to periods
        periods = np.round(fs / frequency)
        periods = np.clip(periods, pitch_min, pitch_max)
        # Encode period the way the feature file expects: (T - 100) / 50.
        output[frame_start:frame_stop, pitch_position] = (periods - 100) / 50
        frame_offset = (pitch_max + frame_length - 1) // frame_length
        offset = frame_offset * frame_length
        padding = lpc_order
        # Pad LPC coefficients with zeros at the start of the stream so the
        # residual covers `offset` samples of lag history.
        if frame_start < frame_offset:
            lpc_coeffs = np.concatenate((np.zeros((frame_offset - frame_start, lpc_order), dtype=np.float32), lpcs[:frame_stop]))
        else:
            lpc_coeffs = lpcs[frame_start - frame_offset : frame_stop]
        pred, error = run_lpc(chunk[history_length - offset - padding :], lpc_coeffs, frame_length=frame_length)
        # Per-frame normalized autocorrelation of the LPC residual at the
        # estimated pitch lag.
        # NOTE(review): the inner loop reuses `i`, shadowing the chunk index;
        # harmless because the outer `i` is unused, but worth renaming.
        xcorr = np.zeros(frame_stop - frame_start)
        for i, p in enumerate(periods.astype(np.int16)):
            if p > 0:
                f1 = error[offset + i * frame_length : offset + (i + 1) * frame_length]
                f2 = error[offset + i * frame_length - p : offset + (i + 1) * frame_length - p]
                xcorr[i] = np.dot(f1, f2) / np.sqrt(np.dot(f1, f1) * np.dot(f2, f2) + 1e-6)
        output[frame_start:frame_stop, xcorr_position] = xcorr - 0.5
        # update buffers and indices
        history = chunk[-history_length :]
        frame_start += chunk_size
        frame_stop += chunk_size
        frame_stop = min(frame_stop, num_frames)
        signal_start = frame_start * frame_length
        signal_stop = frame_stop * frame_length

View File

@@ -0,0 +1,34 @@
# Copy into PTDB root directory and run to combine all the male/female raw audio/references into below directories
# NOTE(review): the `**` globs below require `shopt -s globstar` under bash
# (zsh supports them natively) — confirm the invoking shell before relying on
# recursive matching.
# Make folder for combined audio
mkdir -p './combined_mic_16k/'
# Make folder for combined pitch reference
mkdir -p './combined_reference_f0/'
# Resample Male Audio
for i in ./MALE/MIC/**/*.wav; do
j="$(basename "$i" .wav)"
echo "$j"
sox -r 48000 -b 16 -e signed-integer "$i" -r 16000 -b 16 -e signed-integer ./combined_mic_16k/"$j".raw
done
# Resample Female Audio
for i in ./FEMALE/MIC/**/*.wav; do
j="$(basename "$i" .wav)"
echo "$j"
sox -r 48000 -b 16 -e signed-integer "$i" -r 16000 -b 16 -e signed-integer ./combined_mic_16k/"$j".raw
done
# Copy Male reference pitch files
# BUGFIX: these loops iterate .f0 files, so strip the .f0 suffix (the
# original passed .wav to basename, leaving the extension in place).
for i in ./MALE/REF/**/*.f0; do
j="$(basename "$i" .f0)"
echo "$j"
cp "$i" ./combined_reference_f0/
done
# Copy Female reference pitch files
for i in ./FEMALE/REF/**/*.f0; do
j="$(basename "$i" .f0)"
echo "$j"
cp "$i" ./combined_reference_f0/
done

View File

@@ -0,0 +1,72 @@
"""
Perform Data Augmentation (Gain, Additive Noise, Random Filtering) on Input TTS Data
1. Read in chunks and compute clean pitch first
2. Then add in augmentation (Noise/Level/Response)
- Adds filtered noise from the "Demand" dataset, https://zenodo.org/record/1227121#.XRKKxYhKiUk
- When using the Demand Dataset, consider each channel as a possible noise input, and keep the first 4 minutes of noise for training
3. Use this "augmented" audio for feature computation, and compute pitch using CREPE on the clean input
Notes: To ensure consistency with the discovered CREPE offset, we do the following
- We pad the input audio to the zero-centered CREPE estimator with 80 zeros
- We pad the input audio to our feature computation with 160 zeros to center them
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('data', type=str, help='input raw audio data')
parser.add_argument('output', type=str, help='output directory')
parser.add_argument('--gpu-index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
parser.add_argument('--chunk-size-frames', type=int, help='Number of frames to process at a time',default = 100000,required = False)
args = parser.parse_args()
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)
import numpy as np
import tqdm
import crepe
data = np.memmap(args.data, dtype=np.int16,mode = 'r')
# list_features = []
list_cents = []
list_confidences = []
min_period = 32
max_period = 256
f_ref = 16000/max_period
chunk_size_frames = args.chunk_size_frames
chunk_size = chunk_size_frames*160
nb_chunks = (data.shape[0]+79)//chunk_size+1
output_data = np.zeros((0,2),dtype='float32')
for i in tqdm.trange(nb_chunks):
if i==0:
chunk = np.concatenate([np.zeros(80),data[:chunk_size-80]])
elif i==nb_chunks-1:
chunk = data[i*chunk_size-80:]
else:
chunk = data[i*chunk_size-80:(i+1)*chunk_size-80]
chunk = chunk/np.array(32767.,dtype='float32')
# Clean Pitch/Confidence Estimate
# Padding input to CREPE by 80 samples to ensure it aligns
_, pitch, confidence, _ = crepe.predict(chunk, 16000, center=True, viterbi=True,verbose=0)
pitch = pitch[:chunk_size_frames]
confidence = confidence[:chunk_size_frames]
# Filter out of range pitches/confidences
confidence[pitch < 16000/max_period] = 0
confidence[pitch > 16000/min_period] = 0
pitch = np.reshape(pitch, (-1, 1))
confidence = np.reshape(confidence, (-1, 1))
out = np.concatenate([pitch, confidence], axis=-1, dtype='float32')
output_data = np.concatenate([output_data, out], axis=0)
output_data.tofile(args.output)

View File

@@ -0,0 +1,162 @@
"""
Training the neural pitch estimator
"""
import os
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='.f32 IF Features for training (generated by augmentation script)')
parser.add_argument('features_pitch', type=str, help='.npy Pitch file for training (generated by augmentation script)')
parser.add_argument('output_folder', type=str, help='Output directory to store the model weights and config')
parser.add_argument('data_format', type=str, help='Choice of Input Data',choices=['if','xcorr','both'])
parser.add_argument('--gpu_index', type=int, help='GPU index to use if multiple GPUs',default = 0,required = False)
parser.add_argument('--confidence_threshold', type=float, help='Confidence value below which pitch will be neglected during training',default = 0.4,required = False)
parser.add_argument('--context', type=int, help='Sequence length during training',default = 100,required = False)
parser.add_argument('--N', type=int, help='STFT window size',default = 320,required = False)
parser.add_argument('--H', type=int, help='STFT Hop size',default = 160,required = False)
parser.add_argument('--xcorr_dimension', type=int, help='Dimension of Input cross-correlation',default = 257,required = False)
parser.add_argument('--freq_keep', type=int, help='Number of Frequencies to keep',default = 30,required = False)
parser.add_argument('--gru_dim', type=int, help='GRU Dimension',default = 64,required = False)
parser.add_argument('--output_dim', type=int, help='Output dimension',default = 192,required = False)
parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
args = parser.parse_args()
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_index)
# Fixing the seeds for reproducability
import time
np_seed = int(time.time())
torch_seed = int(time.time())
import torch
torch.manual_seed(torch_seed)
import numpy as np
np.random.seed(np_seed)
from utils import count_parameters
import tqdm
from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.data_format == 'if':
pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
elif args.data_format == 'xcorr':
pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
else:
pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
if type(args.initial_checkpoint) != type(None):
checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)
dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
def loss_custom(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted cross-entropy over pitch bins.

    Args:
        logits: (batch, nmax, time) raw scores; softmax is taken over the
            class axis (dim 1).
        labels: (batch, time) target bin indices (float, cast to long).
        confidence: (batch, time) per-frame weights.
        choice: 'default' for categorical CE (mean-reduced), anything else
            for the generalized/robust CE of Zhang & Sabuncu (sum-reduced).
        nmax: number of pitch bins.
        q: robustness exponent for the robust variant.
    """
    probs = torch.softmax(logits, dim=1).permute(0, 2, 1)
    one_hot = torch.nn.functional.one_hot(labels.long(), nmax)
    if choice == 'default':
        # Categorical cross entropy: the one-hot mask keeps only the target
        # bin's log-probability; eps guards log(0).
        per_frame = -torch.sum(torch.log(probs * one_hot + 1.0e-6) * one_hot, dim=-1)
        return torch.mean(confidence * per_frame)
    # Robust cross entropy: (1 - p^q) / q, summed (not averaged) over frames.
    per_frame = (1.0 / q) * (1 - torch.sum(torch.pow(probs * one_hot + 1.0e-7, q), dim=-1))
    return torch.sum(confidence * per_frame)
def accuracy(logits, labels, confidence, choice='default', nmax=192, q=0.7):
    """Confidence-weighted classification accuracy over pitch bins.

    Returns 1 - mean(confidence * error_indicator); `choice`, `nmax`, and
    `q` are accepted for signature parity with loss_custom but unused.
    """
    probs = torch.softmax(logits, dim=1).permute(0, 2, 1)
    predictions = torch.argmax(probs, 2)
    error_rate = (predictions != labels.long()) * 1.
    return 1. - torch.mean(confidence * error_rate)
# 95/5 train/test split, seeded with the recorded torch seed for
# reproducibility.
train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05], generator=torch.Generator().manual_seed(torch_seed))
batch_size = 256
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
pitch_nn = pitch_nn.to(device)
num_params = count_parameters(pitch_nn)
learning_rate = args.learning_rate
model_opt = torch.optim.Adam(pitch_nn.parameters(), lr = learning_rate)
num_epochs = args.epochs
for epoch in range(num_epochs):
    losses = []
    accs = []
    pitch_nn.train()
    with tqdm.tqdm(train_dataloader) as train_epoch:
        for i, (xi, yi, ci) in enumerate(train_epoch):
            yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
            pi = pitch_nn(xi.float())
            loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
            acc = accuracy(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
            acc = acc.detach()
            model_opt.zero_grad()
            loss.backward()
            model_opt.step()
            losses.append(loss.item())
            accs.append(acc.item())
            # Running means are recomputed every batch just for the
            # progress-bar display.
            avg_loss = np.mean(losses)
            avg_acc = np.mean(accs)
            train_epoch.set_postfix({"Train Epoch" : epoch, "Train Loss":avg_loss, "acc" : avg_acc.item()})
    # Evaluate on the held-out split every 5 epochs (no gradient updates,
    # though note there is no torch.no_grad() guard here).
    if epoch % 5 == 0:
        pitch_nn.eval()
        losses = []
        with tqdm.tqdm(test_dataloader) as test_epoch:
            for i, (xi, yi, ci) in enumerate(test_epoch):
                yi, xi, ci = yi.to(device, non_blocking=True), xi.to(device, non_blocking=True), ci.to(device, non_blocking=True)
                pi = pitch_nn(xi.float())
                loss = loss_custom(logits = pi,labels = yi,confidence = ci,choice = args.choice_cel,nmax = args.output_dim)
                losses.append(loss.item())
                avg_loss = np.mean(losses)
                test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
pitch_nn.eval()
# Persist the full training configuration alongside the weights so the
# inference script can rebuild the right architecture.
config = dict(
    data_format=args.data_format,
    epochs=num_epochs,
    window_size= args.N,
    hop_factor= args.H,
    freq_keep=args.freq_keep,
    batch_size=batch_size,
    learning_rate=learning_rate,
    confidence_threshold=args.confidence_threshold,
    model_parameters=num_params,
    np_seed=np_seed,
    torch_seed=torch_seed,
    xcorr_dim=args.xcorr_dimension,
    dim_input=3*args.freq_keep - 2,
    gru_dim=args.gru_dim,
    output_dim=args.output_dim,
    choice_cel=args.choice_cel,
    context=args.context,
)
model_save_path = os.path.join(args.output_folder, f"{args.prefix}_{args.data_format}.pth")
checkpoint = {
    'state_dict': pitch_nn.state_dict(),
    'config': config
}
torch.save(checkpoint, model_save_path)

View File

@@ -0,0 +1,59 @@
"""
Utility functions that are commonly used
"""
import numpy as np
from scipy.signal import windows, lfilter
from prettytable import PrettyTable
# Source: https://gist.github.com/thongonary/026210fc186eb5056f2b6f1ca362d912
def count_parameters(model):
    """Print a per-module table of trainable parameter counts and return the total."""
    table = PrettyTable(["Modules", "Parameters"])
    total = 0
    for module_name, tensor in model.named_parameters():
        # Frozen parameters are excluded from the count.
        if not tensor.requires_grad:
            continue
        count = tensor.numel()
        table.add_row([module_name, count])
        total += count
    print(table)
    print(f"Total Trainable Params: {total}")
    return total
def stft(x, w='boxcar', N=320, H=160):
    """Short-time Fourier transform of 1-D signal x.

    The signal is zero-padded by N samples at the end; frames of length N are
    taken every H samples and windowed with scipy window `w`. Returns an
    array of shape (num_frames, N//2 + 1) of complex rfft values.
    """
    padded = np.concatenate([x, np.zeros(N)])
    # win_custom = np.concatenate([windows.hann(80)[:40],np.ones(240),windows.hann(80)[40:]])
    win = windows.get_window(w, N)
    frames = []
    for start in np.arange(0, padded.shape[0] - N, H):
        frames.append(np.fft.rfft(padded[start:start + N] * win))
    return np.stack(frames)
def random_filter(x):
    """Filter x with a random second-order IIR filter.

    Both numerator and denominator are monic; the four remaining
    coefficients are drawn uniformly from [-3/8, 3/8). Uses the global numpy
    RNG, so results depend on the current seed state.
    """
    coeffs = np.random.uniform(low=-3.0 / 8, high=3.0 / 8, size=4)
    return lfilter([1, coeffs[0], coeffs[1]], [1, coeffs[2], coeffs[3]], x)
def feature_xform(feature):
    """Build a 3-channel multi-resolution view of LPCNet xcorr features.

    Takes an (N, L) xcorr feature matrix and stacks, along a new last axis:
    a smoothed 2x-downsampled version (zero-padded back to L lags), the
    original, and a smoothed 2x-upsampled version truncated to L lags.
    Returns an (N, L, 3) array.
    """
    from scipy.signal import resample_poly, lfilter
    num_frames, num_lags = feature.shape
    # Upsample by 2 along the lag axis, smooth, and keep the first L lags.
    upsampled = lfilter([0.25, 0.5, 0.25], [1], resample_poly(feature, 2, 1, axis=1), axis=1)[:, :num_lags]
    # Downsample by 2, smooth, then zero-pad back to L lags.
    downsampled = lfilter([0.5, 0.5], [1], resample_poly(feature, 1, 2, axis=1), axis=1)
    pad = np.zeros((num_frames, num_lags - downsampled.shape[1]))
    downsampled = np.concatenate([downsampled, pad], axis=-1)
    return np.stack((downsampled, feature, upsampled), axis=-1)

View File

@@ -0,0 +1,66 @@
# Opus Speech Coding Enhancement
This folder hosts models for enhancing Opus SILK.
## Environment setup
The code is tested with python 3.11. Conda setup is done via
`conda create -n osce python=3.11`
`conda activate osce`
`python -m pip install -r requirements.txt`
## Generating training data
First step is to convert all training items to 16 kHz and 16 bit pcm and then concatenate them. A convenient way to do this is to create a file list and then run
`python scripts/concatenator.py filelist 16000 dataset/clean.s16 --db_min -40 --db_max 0`
which on top provides some random scaling. Data is taken from the datasets listed in dnn/datasets.txt and the exact list of items used for training and validation is
located in dnn/torch/osce/resources.
Second step is to run a patched version of opus_demo in the dataset folder, which will produce the coded output and add feature files. To build the patched opus_demo binary, check out the exp-neural-silk-enhancement branch and build opus_demo the usual way. Then run
`cd dataset && <path_to_patched_opus_demo>/opus_demo voip 16000 1 9000 -silk_random_switching 249 clean.s16 coded.s16 `
The argument to -silk_random_switching specifies the number of frames after which parameters are switched randomly.
## Regression loss based training
Create a default setup for LACE or NoLACE via
`python make_default_setup.py model.yml --model lace/nolace --path2dataset <path2dataset>`
Then run
`python train_model.py model.yml <output folder> --no-redirect`
for running the training script in foreground or
`nohup python train_model.py model.yml <output folder> &`
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
## Adversarial training (NoLACE only)
Create a default setup for NoLACE via
`python make_default_setup.py nolace_adv.yml --model nolace --adversarial --path2dataset <path2dataset>`
Then run
`python adv_train_model.py nolace_adv.yml <output folder> --no-redirect`
for running the training script in foreground or
`nohup python adv_train_model.py nolace_adv.yml <output folder> &`
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
## Inference
Generating inference data is analogous to generating training data. Given an item 'item1.wav' run
`mkdir item1.se && sox item1.wav -r 16000 -e signed-integer -b 16 item1.raw && cd item1.se && <path_to_patched_opus_demo>/opus_demo voip 16000 1 <bitrate> ../item1.raw noisy.s16`
The folder item1.se then serves as input for the test_model.py script, or as the --testdata argument of train_model.py and adv_train_model.py.
autogen.sh downloads pre-trained model weights to the subfolder dnn/models of the main repo.

View File

@@ -0,0 +1,462 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import SilkEnhancementSet
from models import model_dict
from utils.silk_features import load_inference_data
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
args = parser.parse_args()
torch.set_num_threads(4)
# The whole run (model, discriminator, training hyperparameters, dataset
# paths) is driven by one YAML setup file.
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)
# Output-layout constants used when saving checkpoints and redirecting logs.
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
# IDIOM FIX: `'name' not in` replaces `not 'name' in`.
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']
# prepare output folder; ask before reusing an existing one
if os.path.exists(args.output):
    print("warning: output folder exists")
    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')
    if reply == 'n':
        # BUGFIX: the original called os._exit() without the required status
        # argument, which raises TypeError instead of exiting. Use sys.exit
        # so normal interpreter cleanup also runs.
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup (best-effort: disabled on any git failure)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()
        if is_dirty:
            print("warning: repo is dirty")
        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except:
        has_git = False
# dump setup (including repo info) to the output folder for provenance
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)
ref = None
# Optional inference-during-training data; a missing clean reference is
# tolerated (PESQ comparison is then skipped).
if args.testdata is not None:
    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except:
        pass
else:
    inference_test = False
# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
# Generator learning rate is a scaled-down copy of the base lr.
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')
# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)
# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
    run_validation = True
else:
    run_validation = False
# create model (generator), looked up by name from the model registry
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)
# set compute device (CUDA if available unless overridden via --device)
if type(args.device) == type(None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)
# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
# disc optimizer (standard GAN betas)
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
# learning rate scheduler: hyperbolic decay on the generator optimizer only
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
# Resume from an initial checkpoint if given; each optional sub-state is
# restored only when present in the checkpoint.
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])
    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])
    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])
    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])
    # BUGFIX: the original tested the misspelled key 'scheduler_state_disc'
    # while loading 'scheduler_state_dict', so the scheduler state was never
    # restored (or a KeyError was possible). Test the key actually loaded.
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])
    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])
    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])
    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
# loss weights, all read from the YAML setup; w_sum normalizes the combined
# criterion below
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
# Multi-resolution STFT and log-mel losses carry their own sub-weights.
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """1 minus normalized cross-correlation, averaged over the batch.

    Reduces over all axes except the batch axis; ranges from ~0 (perfectly
    correlated) to ~2 (anti-correlated).
    """
    reduce_dims = list(range(1, len(y_true.shape)))
    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    denominator = torch.sqrt(torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims) + 1e-9)
    return torch.mean(1 - numerator / denominator)
def td_l2_norm(y_true, y_pred):
    """Time-domain MSE normalized by the RMS of the prediction, batch-averaged."""
    reduce_dims = list(range(1, len(y_true.shape)))
    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    pred_rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5
    return torch.mean(mse / (pred_rms + 1e-6))
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally normalized by mean |y_pred|**pow.

    pow=0 gives plain MAE; pow=1 gives a relative L1 error.
    """
    reduce_dims = list(range(1, len(y_true.shape)))
    mae = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow
    return torch.mean(mae / scale)
def criterion(x, y):
    """Weighted combination of the regression losses, normalized by w_sum.

    Relies on the module-level weights (w_l1, w_logmel, w_xcorr, w_l2,
    w_sum) and loss modules (stftloss, logmelloss) defined above.
    """
    total = w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)
    return total / w_sum
# model checkpoint skeleton; 'loss' is filled in when a checkpoint is saved
checkpoint = {
    'setup' : setup,
    'state_dict' : model.state_dict(),
    'loss' : -1
}
# Redirect stdout into the output folder unless --no-redirect was given.
if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")
print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
# Baseline quality of the coded (un-enhanced) test signal, if available.
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")
best_loss = 1e9
log_interval = 10
# Running mean/std estimates of real (r) and fake (f) discriminator scores,
# used below to report the discriminator's "winning chance".
m_r = 0
m_f = 0
s_r = 1
s_f = 1
def optimizer_to(optim, device):
    """Move every tensor in the optimizer state (and any attached gradient) onto *device*, in place."""
    def _move_tensor(t):
        t.data = t.data.to(device)
        if t._grad is not None:
            t._grad.data = t._grad.data.to(device)

    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            _move_tensor(state)
        elif isinstance(state, dict):
            for entry in state.values():
                if isinstance(entry, torch.Tensor):
                    _move_tensor(entry)
# make sure any restored optimizer state lives on the training device
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)
retain_grads(model)
retain_grads(disc)

# NOTE(review): `scheduler` state is saved below but scheduler.step() is never
# called in this loop -- confirm the learning-rate decay is intended to be inactive
for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    # per-epoch running sums for the status bar
    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0
    running_disc_grad_norm = 0
    running_model_grad_norm = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):
            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # discriminator update (output detached: only disc receives gradients)
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            # LSGAN-style discriminator loss over all sub-discriminators
            disc_loss = 0
            for score in scores_gen:
                disc_loss += (((score[-1]) ** 2)).mean()
                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()
            for score in scores_real:
                disc_loss += (((1 - score[-1]) ** 2)).mean()
                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()
            disc_loss = 0.5 * disc_loss / len(scores_gen)

            # probability that a fake score beats a real one under a Gaussian model
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()
            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()
            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            gen_loss = 0
            for score in scores_gen:
                gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs

            # feature-matching loss over intermediate discriminator layers
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()
            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
            optimizer.step()

            # sparsification (re-applies the sparsity masks after the update)
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
            running_adv_loss += gen_loss.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
                                   disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")

    # save checkpoint (per-epoch file plus a rolling `_last` copy)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss/(i + 1)
    checkpoint['disc_loss'] = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)

    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

print()
print('Done')

View File

@@ -0,0 +1,451 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import LPCNetVocodingDataset
from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
# command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

torch.set_num_threads(4)

# load the experiment description
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# fixed file-naming conventions inside the output folder
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'
# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires an exit status; calling it without
        # arguments raised TypeError instead of terminating cleanly
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup (for experiment provenance)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except:
        # not a git checkout (or no remote) -- continue without repo info
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

ref = None

# prepare inference test if wanted
inference_test = False
if type(args.test_features) != type(None):
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
# generator runs at a reduced learning rate relative to the discriminator
lr_gen = lr * setup['training']['gen_lr_reduction']
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
    run_validation = True
else:
    # NOTE(review): run_validation is set but never used later in this script
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device
if type(args.device) == type(None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# disc optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler (hyperbolic decay in the step index)
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
# optionally resume from an initial checkpoint; each component (discriminator,
# optimizers, scheduler, RNG states) is restored only if present in the file
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: this previously tested for the key 'scheduler_state_disc',
    # which is never written by the checkpoint saver, so the scheduler state
    # was silently never restored on resume
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
# loss
# individual loss weights from the setup file
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

# total weight, used to normalize the combined loss in criterion()
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

# multi-resolution STFT and log-mel losses on the training device
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """Loss based on normalized cross-correlation (~0 when aligned, up to 2 when anti-correlated)."""
    d = list(range(1, len(y_true.shape)))
    corr = torch.sum(y_true * y_pred, dim=d) / torch.sqrt(
        torch.sum(y_true ** 2, dim=d) * torch.sum(y_pred ** 2, dim=d) + 1e-9)
    return (1 - corr).mean()
def td_l2_norm(y_true, y_pred):
    """Mean squared error normalized by an RMS-like energy of the prediction."""
    d = list(range(1, len(y_true.shape)))
    mse = ((y_true - y_pred) ** 2).mean(dim=d)
    denom = (y_pred ** 2).mean(dim=d) ** .5 + 1e-6
    return torch.mean(mse / denom)
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally scaled by (mean |y_pred|) ** pow."""
    d = list(range(1, len(y_true.shape)))
    num = (y_true - y_pred).abs().mean(dim=d)
    den = (y_pred.abs().mean(dim=d) + 1e-9) ** pow
    return (num / den).mean()
def criterion(x, y):
    # Combined regression loss: weighted sum of time-domain L1, multi-resolution
    # STFT, log-mel, cross-correlation and normalized-L2 terms, divided by the
    # total weight w_sum (the stftloss term carries its own sub-weights).
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# model checkpoint (filled in and re-saved at the end of every epoch)
checkpoint = {
    'setup' : setup,
    'state_dict' : model.state_dict(),
    'loss' : -1
}

if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

# NOTE(review): `ref` is assigned None earlier in this script and never
# reassigned, so this branch is currently dead code (and args.testdata is not
# a declared argument)
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10

# running means (m_*) and stds (s_*) of discriminator scores on real (r) and
# fake (f) data, used for the "winning chance" shown in the progress bar
m_r = 0
m_f = 0
s_r = 1
s_f = 1
def optimizer_to(optim, device):
    """Move all tensors stored in the optimizer state (including attached gradients) onto *device*."""
    for value in optim.state.values():
        if isinstance(value, torch.Tensor):
            tensors = [value]
        elif isinstance(value, dict):
            tensors = [v for v in value.values() if isinstance(v, torch.Tensor)]
        else:
            tensors = []
        for t in tensors:
            t.data = t.data.to(device)
            if t._grad is not None:
                t._grad.data = t._grad.data.to(device)
# make sure any restored optimizer state lives on the training device
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)

# NOTE(review): `scheduler` state is saved below but scheduler.step() is never
# called in this loop -- confirm the learning-rate decay is intended to be inactive
for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    # per-epoch running sums for the status bar
    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):
            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # discriminator update (output detached: only disc receives gradients)
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            # LSGAN-style discriminator loss over all sub-discriminators
            disc_loss = 0
            for scale in scores_gen:
                disc_loss += ((scale[-1]) ** 2).mean()
                m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
            for scale in scores_real:
                disc_loss += ((1 - scale[-1]) ** 2).mean()
                m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
            disc_loss = 0.5 * disc_loss / len(scores_gen)

            # probability that a fake score beats a real one under a Gaussian model
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

            # generator update
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            loss_gen = 0
            for scale in scores_gen:
                loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs

            # feature-matching loss over intermediate discriminator layers
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()
            (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
            optimizer.step()

            running_adv_loss += loss_gen.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")

    # save checkpoint (per-epoch file plus a rolling `_last` copy)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss/(i + 1)
    checkpoint['disc_loss'] = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)

    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

print()
print('Done')

View File

@@ -0,0 +1,165 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import torch
import numpy as np
from models import model_dict
from utils import endoscopy
# command line arguments for the testvector generator
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and nolace_checkpoint.pth"')
parser.add_argument('output_folder', type=str, help='output folder for testvectors')
parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
    """Generate test vectors for an AdaConv module.

    Feeds random features and a random input signal through *adaconv* and
    writes features, input and output as raw float32 files under *prefix*.
    """
    c_in = adaconv.in_channels
    c_out = adaconv.out_channels
    frame_size = adaconv.frame_size

    features = torch.randn((1, num_frames, adaconv.feature_dim))
    x_in = torch.randn((1, c_in, num_frames * frame_size))
    x_out = adaconv(x_in, features, debug=debug)

    def to_frames(x, channels):
        # (1, channels, num_frames * frame_size) -> (num_frames, channels, frame_size)
        return x[0].reshape(channels, num_frames, frame_size).permute(1, 0, 2).detach().numpy()

    features[0].detach().numpy().tofile(prefix + '_features.f32')
    to_frames(x_in, c_in).tofile(prefix + '_x_in.f32')
    to_frames(x_out, c_out).tofile(prefix + '_x_out.f32')
def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
    """Generate test vectors for an AdaComb (adaptive comb filter) module.

    Pitch lags are drawn uniformly from [kernel_size, 250) so every lag is
    valid for the comb filter's kernel length.
    """
    frame_size = adacomb.frame_size

    features = torch.randn((1, num_frames, adacomb.feature_dim))
    # single input channel
    x_in = torch.randn((1, 1, num_frames * frame_size))
    p_in = torch.randint(adacomb.kernel_size, 250, (1, num_frames))

    x_out = adacomb(x_in, features, p_in, debug=debug)

    features[0].detach().numpy().tofile(prefix + '_features.f32')
    x_in[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_in.f32')
    p_in[0].detach().numpy().astype(np.int32).tofile(prefix + '_p_in.s32')
    x_out[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_out.f32')
def create_adashape_testvector(prefix, adashape, num_frames):
    """Generate test vectors for an AdaShape (temporal shaping) module."""
    features = torch.randn((1, num_frames, adashape.feature_dim))
    x_in = torch.randn((1, 1, num_frames * adashape.frame_size))

    x_out = adashape(x_in, features)

    features[0].detach().numpy().tofile(prefix + '_features.f32')
    x_in.flatten().detach().numpy().tofile(prefix + '_x_in.f32')
    x_out.flatten().detach().numpy().tofile(prefix + '_x_out.f32')
def create_feature_net_testvector(prefix, model, num_frames):
    """Generate test vectors for the feature net of *model*.

    Features and pitch periods are sampled at subframe rate (4 subframes per
    frame), numbits at frame rate; the numbits embedding is repeated to
    subframe rate before concatenation.
    """
    num_subframes = 4 * num_frames

    input_features = torch.randn((1, num_subframes, model.num_features))
    periods = torch.randint(32, 300, (1, num_subframes))
    lo, hi = model.numbits_range
    numbits = lo + torch.rand((1, num_frames, 2)) * (hi - lo)

    pembed = model.pitch_embedding(periods)
    nembed = torch.repeat_interleave(model.numbits_embedding(numbits).flatten(2), 4, dim=1)
    full_features = torch.cat((input_features, pembed, nembed), dim=-1)

    cf = model.feature_net(full_features)

    input_features.float().numpy().tofile(prefix + "_in_features.f32")
    periods.numpy().astype(np.int32).tofile(prefix + "_periods.s32")
    numbits.float().numpy().tofile(prefix + "_numbits.f32")
    full_features.detach().numpy().tofile(prefix + "_full_features.f32")
    cf.detach().numpy().tofile(prefix + "_out_features.f32")
if __name__ == "__main__":
    args = parser.parse_args()
    os.makedirs(args.output_folder, exist_ok=True)

    # load LACE and NoLACE checkpoints and rebuild both models from their stored setups
    lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
    nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')

    lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
    nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])

    lace.load_state_dict(lace_checkpoint['state_dict'])
    nolace.load_state_dict(nolace_checkpoint['state_dict'])

    if args.debug:
        endoscopy.init(args.output_folder)

    # lace af1, 1 input channel, 1 output channel
    create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, 5, debug=args.debug)

    # nolace af1, 1 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, 5, debug=args.debug)

    # nolace af4, 2 input channel, 1 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, 5, debug=args.debug)

    # nolace af2, 2 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, 5, debug=args.debug)

    # lace cf1
    create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, 5, debug=args.debug)

    # nolace tdshape1
    create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, 5)

    # lace feature net
    create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, 5)

    if args.debug:
        endoscopy.close()

View File

@@ -0,0 +1,2 @@
from .silk_enhancement_set import SilkEnhancementSet
from .lpcnet_vocoding_dataset import LPCNetVocodingDataset

View File

@@ -0,0 +1,225 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
# mu-law <-> linear scaling constants (8-bit mu-law over the 16-bit PCM range)
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
    """Convert 8-bit mu-law sample values (0..255, 128 = silence) to linear 16-bit PCM amplitudes."""
    inv_scale = 32768.0 / 255.0
    centered = u - 128
    s = np.sign(centered)
    mag = np.abs(centered)
    return s * inv_scale * (np.exp(mag / 128. * np.log(256)) - 1)
def lin2ulaw(x):
    """Convert linear 16-bit PCM amplitudes to 8-bit mu-law sample values (0..255)."""
    mu_scale = 255.0 / 32768.0
    s = np.sign(x)
    mag = np.abs(x)
    u = s * (128 * np.log(1 + mu_scale * mag) / np.log(256))
    return np.clip(128 + np.round(u), 0, 255)
def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise LPC prediction to *signal*.

    *signal* must carry lpc_order extra history samples at the front; *lpcs*
    has shape (num_frames, lpc_order). Returns (prediction, error), each of
    length num_frames * frame_length.
    """
    num_frames, lpc_order = lpcs.shape

    frame_predictions = []
    for k in range(num_frames):
        # per-frame segment including lpc_order - 1 history samples
        segment = signal[k * frame_length : (k + 1) * frame_length + lpc_order - 1]
        frame_predictions.append(- np.convolve(segment, lpcs[k], mode='valid'))

    prediction = np.concatenate(frame_predictions)
    error = signal[lpc_order :] - prediction

    return prediction, error
class LPCNetVocodingDataset(Dataset):
def __init__(self,
path_to_dataset,
features=['cepstrum', 'periods', 'pitch_corr'],
target='signal',
frames_per_sample=100,
feature_history=0,
feature_lookahead=0,
lpc_gamma=1):
super().__init__()
# load dataset info
self.path_to_dataset = path_to_dataset
with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
dataset = yaml.load(f, yaml.FullLoader)
# dataset version
self.version = dataset['version']
if self.version == 1:
self.getitem = self.getitem_v1
elif self.version == 2:
self.getitem = self.getitem_v2
else:
raise ValueError(f"dataset version {self.version} unknown")
# features
self.feature_history = feature_history
self.feature_lookahead = feature_lookahead
self.frame_offset = 2 + self.feature_history
self.frames_per_sample = frames_per_sample
self.input_features = features
self.feature_frame_layout = dataset['feature_frame_layout']
self.lpc_gamma = lpc_gamma
# load feature file
self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
self.feature_frame_length = dataset['feature_frame_length']
assert len(self.features) % self.feature_frame_length == 0
self.features = self.features.reshape((-1, self.feature_frame_length))
# derive number of samples is dataset
self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample
# signals
self.frame_length = dataset['frame_length']
self.signal_frame_layout = dataset['signal_frame_layout']
self.target = target
# load signals
self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
self.signal_frame_length = dataset['signal_frame_length']
self.signals = self.signals.reshape((-1, self.signal_frame_length))
assert len(self.signals) == len(self.features) * self.frame_length
def __getitem__(self, index):
return self.getitem(index)
def getitem_v2(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
# calculate prediction and error if lpc coefficients present and prediction not given
if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
# lpc coefficients with one frame lookahead
# frame positions (start one frame early for past excitation)
frame_start = self.frame_offset + self.frames_per_sample * index - 1
frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
# feature positions
lpc_start, lpc_stop = self.feature_frame_layout['lpc']
lpc_order = lpc_stop - lpc_start
lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
# LPC weighting
lpc_order = lpc_stop - lpc_start
weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
lpcs = lpcs * weights
# signal position (lpc_order samples as history)
signal_start = frame_start * self.frame_length - lpc_order + 1
signal_stop = frame_stop * self.frame_length + 1
noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
# extract signals
offset = self.frame_length
sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
# calculate error between real signal and noisy prediction
sample['error'] = sample['signal'] - sample['prediction']
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
target = torch.FloatTensor(sample[self.target]) / 2**15
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'target' : target}
def getitem_v1(self, index):
sample = dict()
# extract features
frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
for feature in self.input_features:
feature_start, feature_stop = self.feature_frame_layout[feature]
sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
# convert periods
if 'periods' in self.input_features:
sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
# last_signal and signal are always expected to be there
for signal_name, index in self.signal_frame_layout.items():
sample[signal_name] = self.signals[signal_start : signal_stop, index]
# concatenate features
feature_keys = [key for key in self.input_features if not key.startswith("periods")]
features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
target = torch.LongTensor(sample[self.target])
periods = torch.LongTensor(sample['periods'])
return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
def __len__(self):
    # Number of samples in the dataset, computed once at construction time.
    return self.dataset_length

View File

@@ -0,0 +1,132 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
from torch.utils.data import Dataset
import numpy as np
from utils.silk_features import silk_feature_factory
from utils.pitch import hangover, calculate_acorr_window
class SilkEnhancementSet(Dataset):
    """Dataset of SILK-coded speech plus precomputed codec features.

    Loads flat binary feature dumps from *path* and serves slices of
    ``frames_per_sample`` frames of 80 samples each, together with the
    coded signal, rate information and LPCNet features.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=9,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_offset=False,
                 add_double_lag_acorr=False
                 ):

        assert frames_per_sample % 4 == 0

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr

        # bug fix: __getitem__ referenced an undefined self.skip; default to
        # no offset between frame grid and signal.
        # TODO(review): confirm the intended skip value (a sibling variant
        # of this class uses skip=91).
        self.skip = 0

        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
        self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
        # bug fix: np.from_file does not exist; the function is np.fromfile
        self.lpcnet_features = np.fromfile(os.path.join(path, 'features_lpcnet.f32'), dtype=np.float32).reshape(-1, 36)
        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_offset,
                                                    add_double_lag_acorr)

        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some frames to have enough signal history
        self.skip_frames = 4 * ((self.history_len + 319) // 320 + 2)

        # bug fix: the original used self.clean_signal here, which is never
        # loaded in this variant; the coded signal determines usable length
        num_frames = self.coded_signal.shape[0] // 80 - self.skip_frames
        self.len = num_frames // frames_per_sample

    def __len__(self):
        """Number of samples in the dataset."""
        return self.len

    def __getitem__(self, index):
        """Return silk features, periods, rate info and LPCNet features."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # int16 PCM -> float in [-1, 1)
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15
        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop],
            self.offsets[frame_start : frame_stop]
        )

        # LPCNet features are at half the frame rate; only the first 20 dims
        lpcnet_features = self.lpcnet_features[frame_start // 2 : frame_stop // 2, :20]

        # rate info has one value per 4 frames -> repeat to per-frame rows
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'silk_features' : features,
            'periods' : periods.astype(np.int64),
            'numbits' : numbits.astype(np.float32),
            'lpcnet_features' : lpcnet_features
        }

View File

@@ -0,0 +1,140 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
from torch.utils.data import Dataset
import numpy as np
from utils.silk_features import silk_feature_factory
from utils.pitch import hangover, calculate_acorr_window
class SilkEnhancementSet(Dataset):
    """Dataset pairing SILK-coded speech with its clean reference.

    Serves slices of ``frames_per_sample`` frames of 80 samples each:
    codec features, pitch periods, rate information and the (optionally
    pre-emphasized) clean and coded time-domain signals.
    """

    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=256,
                 preemph=0.85,
                 skip=91,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_double_lag_acorr=False,
                 ):

        assert frames_per_sample % 4 == 0

        def load(filename, dtype):
            # all feature dumps live as flat binary files inside *path*
            return np.fromfile(os.path.join(path, filename), dtype=dtype)

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.preemph = preemph
        self.skip = skip
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr

        self.lpcs = load('features_lpc.f32', np.float32).reshape(-1, 16)
        self.ltps = load('features_ltp.f32', np.float32).reshape(-1, 5)
        self.periods = load('features_period.s16', np.int16)
        self.gains = load('features_gain.f32', np.float32)
        self.num_bits = load('features_num_bits.s32', np.int32)
        self.num_bits_smooth = load('features_num_bits_smooth.f32', np.float32)
        self.clean_signal_hp = load('clean_hp.s16', np.int16)
        self.clean_signal = load('clean.s16', np.int16)
        self.coded_signal = load('coded.s16', np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_double_lag_acorr)

        # the double-lag autocorrelation features need twice the history
        self.history_len = 700 if add_double_lag_acorr else 350
        # discard leading frames so every sample has enough signal history
        self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)

        usable_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames
        self.len = usable_frames // frames_per_sample

    def __len__(self):
        """Number of samples in the dataset."""
        return self.len

    def __getitem__(self, index):
        """Return features, periods, rate info and signal slices for one sample."""
        first_frame = self.frames_per_sample * index + self.skip_frames
        last_frame = first_frame + self.frames_per_sample

        sample_begin = first_frame * self.frame_size - self.skip
        sample_end = last_frame * self.frame_size - self.skip

        def pcm(buf, start, stop):
            # int16 PCM -> float in [-1, 1)
            return buf[start : stop].astype(np.float32) / 2**15

        clean_signal_hp = pcm(self.clean_signal_hp, sample_begin, sample_end)
        clean_signal = pcm(self.clean_signal, sample_begin, sample_end)
        coded_signal = pcm(self.coded_signal, sample_begin, sample_end)
        coded_signal_history = pcm(self.coded_signal, sample_begin - self.history_len, sample_begin)

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[first_frame : last_frame],
            self.gains[first_frame : last_frame],
            self.ltps[first_frame : last_frame],
            self.periods[first_frame : last_frame]
        )

        if self.preemph > 0:
            # first-order pre-emphasis, applied in place to all signals
            for sig in (clean_signal, clean_signal_hp, coded_signal):
                sig[1:] -= self.preemph * sig[: -1]

        def rate_rows(buf):
            # rate dumps carry one value per 4 frames -> repeat per frame
            return np.repeat(buf[first_frame // 4 : last_frame // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((rate_rows(self.num_bits), rate_rows(self.num_bits_smooth)), axis=-1)

        return {
            'features' : features,
            'periods' : periods.astype(np.int64),
            'target_orig' : clean_signal.astype(np.float32),
            'target' : clean_signal_hp.astype(np.float32),
            'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
            'numbits' : numbits.astype(np.float32)
        }

View File

@@ -0,0 +1,103 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the average loss.

    Parameters
    ----------
    model : torch.nn.Module
        Called as model(signals, features, periods, numbits); may return a
        single tensor or a list of per-stage predictions.
    criterion : callable
        Called as criterion(target, prediction).
    optimizer : torch.optim.Optimizer
    dataloader : iterable of dict batches with keys 'signals', 'features',
        'periods', 'numbits' and 'target'.
    device : torch.device or str
    scheduler : LR scheduler; stepped once per batch, not per epoch.
    log_interval : int
        Progress-bar refresh interval in batches.

    Returns
    -------
    float
        Loss averaged over all batches.
    """
    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # calculate loss; some models return one prediction per stage
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # sparsification (only if the model provides a sparsifier)
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            # update running loss; idiom fix: loss.item() instead of
            # float(loss.cpu()) — same value, clearer intent
            running_loss += loss.item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model on *dataloader* and return the average loss.

    Mirrors train_one_epoch but runs under torch.no_grad() and performs no
    optimizer/scheduler updates.
    """
    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                # consistency fix: handle multi-stage (list) outputs the same
                # way train_one_epoch does; the original crashed on lists
                if isinstance(output, list):
                    loss = torch.zeros(1, device=device)
                    for y in output:
                        loss = loss + criterion(target, y.squeeze(1))
                    loss = loss / len(output)
                else:
                    loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += loss.item()

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

        running_loss /= len(dataloader)

        return running_loss

View File

@@ -0,0 +1,101 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the average loss.

    Parameters
    ----------
    model : torch.nn.Module
        Called as model(features, periods); may return a single tensor or a
        list of per-stage predictions.
    criterion : callable
        Called as criterion(target, prediction).
    optimizer : torch.optim.Optimizer
    dataloader : iterable of dict batches with keys 'features', 'periods'
        and 'target'.
    device : torch.device or str
    scheduler : LR scheduler; stepped once per batch, not per epoch.
    log_interval : int
        Progress-bar refresh interval in batches.

    Returns
    -------
    float
        Loss averaged over all batches.
    """
    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # calculate loss; some models return one prediction per stage
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # update running loss; idiom fix: loss.item() instead of
            # float(loss.cpu()) — same value, clearer intent
            running_loss += loss.item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss

    running_loss /= len(dataloader)

    return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model on *dataloader* and return the average loss.

    Mirrors train_one_epoch but runs under torch.no_grad() and performs no
    optimizer/scheduler updates.
    """
    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'])

                # consistency fix: handle multi-stage (list) outputs the same
                # way train_one_epoch does; the original crashed on lists
                if isinstance(output, list):
                    loss = torch.zeros(1, device=device)
                    for y in output:
                        loss = loss + criterion(target, y.squeeze(1))
                    loss = loss / len(output)
                else:
                    loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += loss.item()

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss

        running_loss /= len(dataloader)

        return running_loss

View File

@@ -0,0 +1,174 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import hashlib
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
import torch
import wexchange.torch
from wexchange.torch import dump_torch_weights
from models import model_dict
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.misc import remove_all_weight_norm
from wexchange.torch import dump_torch_weights
# command-line interface for the checkpoint-to-C exporter
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')

# default sparsification flag for layers that support it
sparse_default=False

# Per-model export schedules: ordered (submodule path, dump_torch_weights
# kwargs) pairs. Entries with empty kwargs are dumped as plain float;
# scale=None lets the exporter choose quantization scales itself.
schedules = {
    'nolace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None)),
        ('tdshape1', dict(quantize=True, scale=None)),
        ('tdshape2', dict(quantize=True, scale=None)),
        ('tdshape3', dict(quantize=True, scale=None)),
        ('af2', dict(quantize=True, scale=None)),
        ('af3', dict(quantize=True, scale=None)),
        ('af4', dict(quantize=True, scale=None)),
        ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
    ],
    'lace' : [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None))
    ]
}
# auxiliary functions
def sha1(filename):
    """Return the hex SHA-1 digest of *filename*, read in 64 KiB chunks."""
    BUF_SIZE = 65536
    # fix: the original bound a local named 'sha1', shadowing this function
    hasher = hashlib.sha1()
    with open(filename, 'rb') as f:
        # iter(..., b'') yields chunks until EOF, replacing while/break
        for data in iter(lambda: f.read(BUF_SIZE), b''):
            hasher.update(data)
    return hasher.hexdigest()
def osce_dump_generic(writer, name, module):
    """Recursively dump all exportable layers of *module*.

    Leaf layers of a known type are written via dump_torch_weights; other
    modules are descended into, extending the exported name with the child
    name ('feature_net' is abbreviated to 'fnet').
    """
    # idiom fix: a single isinstance with a tuple replaces the original
    # chain of or-ed isinstance calls
    dumpable_types = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d,
                      torch.nn.Embedding, LimitedAdaptiveConv1d, LimitedAdaptiveComb1d,
                      TDShaper, torch.nn.GRU)

    if isinstance(module, dumpable_types):
        dump_torch_weights(writer, module, name=name, verbose=True)
    else:
        for child_name, child in module.named_children():
            osce_dump_generic(writer, (name + "_" + child_name).replace("feature_net", "fnet"), child)
def export_name(name):
    """Translate a PyTorch submodule path into a C-friendly identifier."""
    # dots separate submodules in PyTorch; the exported names use
    # underscores and abbreviate 'feature_net' to 'fnet'
    return name.replace('.', '_').replace('feature_net', 'fnet')
def osce_scheduled_dump(writer, prefix, model, schedule):
    """Dump the submodules listed in *schedule*, prefixing exported names.

    *schedule* is a list of (submodule path, dump kwargs) pairs as defined
    in the module-level schedules dict.
    """
    base = prefix if prefix.endswith('_') else prefix + '_'
    for submodule_path, dump_kwargs in schedule:
        dump_torch_weights(writer,
                           model.get_submodule(submodule_path),
                           base + export_name(submodule_path),
                           **dump_kwargs,
                           verbose=True)
if __name__ == "__main__":
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # dump message records the provenance of the generated C sources
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # create model and load weights
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    # weight norm must be removed before exporting raw weight tensors
    remove_all_weight_norm(model, verbose=True)

    # CWriter emitting <model_name>_data sources
    model_name = checkpoint['setup']['model']['name']
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)

    # Add custom includes and global parameters
    cwriter.header.write(f'''
#define {model_name.upper()}_PREEMPH {model.preemph}f
#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
#define {model_name.upper()}_OVERLAP_SIZE 40
#define {model_name.upper()}_NUM_FEATURES {model.num_features}
#define {model_name.upper()}_PITCH_MAX {model.pitch_max}
#define {model_name.upper()}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
#define {model_name.upper()}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
#define {model_name.upper()}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
#define {model_name.upper()}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
#define {model_name.upper()}_COND_DIM {model.cond_dim}
#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')

    # emit one define per scale factor of the numbits embedding
    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")

    # dump layers; quantized export follows the per-model schedule when
    # --quantize is given and a schedule exists for this model
    if model_name in schedules and args.quantize:
        osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
    else:
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()

View File

@@ -0,0 +1,277 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
    """Look up a torch window function by name and instantiate it.

    Args:
        win_name (str): One of the torch window function names below.
        win_length (int): Window length in samples.
        *args, **kwargs: Forwarded to the window function (e.g. beta for
            kaiser_window).

    Returns:
        Tensor: The window of length *win_length*.

    Raises:
        ValueError: If *win_name* is not a supported window function.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window' : torch.hamming_window,
        'hann_window' : torch.hann_window,
        'kaiser_window' : torch.kaiser_window
    }
    if win_name not in window_dict:
        # bug fix: the original raised a bare ValueError with no message
        raise ValueError(f"unknown window function {win_name}")

    return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function name (see get_window).

    Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames),
        clamped below at 1e-7 so downstream logs stay finite.
        (Doc fix: torch.stft returns frequency bins on dim 1 and frames on
        dim 2, not the other way around.)
    """
    win = get_window(window, win_length).to(x.device)
    x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)

    return torch.clamp(torch.abs(x_stft), min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
    """Spectral convergence: per-item Frobenius distance of magnitudes,
    normalized by ||Y_pred||.

    NOTE(review): the common convention normalizes by ||Y_true||; the
    original normalized by ||Y_pred|| and that behavior is preserved here.
    """
    reduce_dims = list(range(1, Y_pred.dim()))
    numerator = torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=reduce_dims)
    denominator = torch.norm(Y_pred, p="fro", dim=reduce_dims) + 1e-6
    return torch.mean(numerator / denominator)
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference of log magnitudes (eps = 1e-15)."""
    eps = 1e-15
    log_true = torch.log(Y_true.abs() + eps)
    log_pred = torch.log(Y_pred.abs() + eps)
    return (log_true - log_pred).abs().mean()
def spectral_xcorr_loss(Y_true, Y_pred):
    """1 - normalized cross-correlation of the magnitude spectrograms."""
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, mag_pred.dim()))
    energy = torch.sqrt(torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9)
    xcorr = torch.sum(mag_true * mag_pred, dim=reduce_dims) / energy
    return 1 - xcorr.mean()
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel-spectrogram L1 loss.

    Compares log mel spectrograms of target and prediction at several FFT
    resolutions and averages the per-resolution losses.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        super().__init__()

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            # halve the mel resolution for short FFTs
            if fft_size < 128:
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register the transforms as submodules so .to(device) moves them
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the log-mel L1 loss averaged over all resolutions."""
        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
def create_weight_matrix(num_bins, bins_per_band=10):
    """Banded averaging matrix: row i averages roughly *bins_per_band* bins
    centred on bin i, with the part of the band falling off either edge of
    the spectrum folded back in."""
    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    below = bins_per_band // 2
    above = bins_per_band - below

    for row in range(num_bins):
        lo = max(row - below, 0)
        hi = min(row + above, num_bins)
        weights[row, lo: hi] += 1

        # fold back the band portion that falls below bin 0
        if row < below:
            weights[row, :below - row] += 1

        # fold back at the upper edge (negative index counts from the end)
        if row > num_bins - above:
            weights[row, num_bins - above - row:] += 1

    return weights / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
# calculate sfm based weights
logY = torch.log(torch.abs(Y_true) + 1e-9)
Y = torch.abs(Y_true)
avg_logY = torch.matmul(logY.transpose(1, 2), w)
avg_Y = torch.matmul(Y.transpose(1, 2), w)
sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
loss = torch.mean(
torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
/ (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
)
return loss
def gen_filterbank(N, Fs=16000):
    """Build an (N x N+1) ERB-spaced smoothing filterbank; each row is
    normalized to sum to one."""
    freq_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    freq_out = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_width = 24.7 + .108 * freq_in

    # response in dB: quadratic roll-off near the band centre, linear tails
    delta = np.abs(freq_in - freq_out) / erb_width
    near_center = (delta < .5).astype('float32')
    response_db = -12 * near_center * delta ** 2 + (1 - near_center) * (3 - 12 * delta)

    response = 10. ** (response_db / 10.)
    response = response / np.sum(response, axis=1)[:, np.newaxis]

    return torch.from_numpy(response)
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-magnitude distance after band smoothing with
    *filterbank* (see gen_filterbank)."""
    smooth_true = torch.matmul(filterbank, torch.abs(Y_true))
    smooth_pred = torch.matmul(filterbank, torch.abs(Y_pred))
    return torch.abs(
        torch.log(smooth_true + 1e-9) - torch.log(smooth_pred + 1e-9)
    ).mean()
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss combining several magnitude-domain terms.

    The total loss is the weighted sum — averaged over FFT sizes — of:
    log-magnitude L1, spectral convergence, SFM-weighted spectral
    convergence, band-smoothed log magnitude and spectral cross-correlation.
    Terms with zero weight are skipped entirely.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss; kept as
        # non-trainable Parameters so they follow the module across devices
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width of roughly 1 kHz, rounded up to an even bin count
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )

    def forward(self, y_true, y_pred):
        """Return the combined multi-resolution loss for (B, T) signals.

        Bug fix: this class used to override __call__ directly, which
        bypasses nn.Module's hook machinery; defining forward() is
        call-compatible (module(y_true, y_pred) still works) and restores
        hook support.
        """
        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)

        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                      + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                      + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss

View File

@@ -0,0 +1,34 @@
import torch
import scipy.signal
from utils.layers.fir import FIR
class TDLowpass(torch.nn.Module):
    """Time-domain lowpass loss.

    The prediction error y_true - y_pred is filtered with an FIR lowpass
    (scipy firwin design) and the mean of |error|**power is returned.
    """
    def __init__(self, numtaps, cutoff, power=2):
        """numtaps/cutoff are passed to scipy.signal.firwin; power is the
        exponent applied to the lowpassed error."""
        super().__init__()

        self.b = scipy.signal.firwin(numtaps, cutoff)
        self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
        self.power = power

    def forward(self, y_true, y_pred):
        """Return (loss, lowpassed error). Inputs must be (B, 1, T)."""
        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

        diff = y_true - y_pred
        # bug fix: self.weight is a plain tensor (not a registered buffer),
        # so model.to(device) does not move it; move it alongside the input
        diff_lp = torch.nn.functional.conv1d(diff, self.weight.to(diff.device))

        loss = torch.mean(torch.abs(diff_lp ** self.power))

        return loss, diff_lp

    def get_freqz(self):
        """Return the filter's frequency response via scipy.signal.freqz."""
        freq, response = scipy.signal.freqz(self.b)

        return freq, response

Some files were not shown because too many files have changed in this diff Show More