add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jean-Marc Valin */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
# x is (batch, nb_in_channels, nb_frames*frame_size)
# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """ FFT-based frame-wise adaptive convolution with overlap-add cross-fade.

    Parameters:
    -----------
    x : torch.tensor
        input signal of shape (batch, nb_in_channels, nb_frames * frame_size)
    kernels : torch.tensor
        per-frame filters of shape (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
    half_window : torch.tensor
        rising half of the cross-fade window; its length defines the overlap size
    fft_size : int
        FFT length; must satisfy fft_size >= 2 * frame_size + overlap_size
        (otherwise the zero-pad size below becomes negative)

    Returns the filtered signal of shape (batch, nb_out_channels, nb_frames * frame_size).
    """
    device=x.device
    overlap_size=half_window.size(-1)
    nb_frames=kernels.size(3)
    nb_batches=kernels.size(0)
    nb_out_channels=kernels.size(1)
    nb_in_channels=kernels.size(2)
    kernel_size = kernels.size(-1)

    # split signal into frames; the singleton dim broadcasts against out channels
    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)

    # build window: [zeros, rising window, ones, falling window, zeros]
    window = torch.cat(
        [
            torch.zeros(frame_size, device=device),
            half_window,
            torch.ones(frame_size - overlap_size, device=device),
            1 - half_window,
            torch.zeros(fft_size - 2 * frame_size - overlap_size, device=device)
        ])

    # each FFT block holds [previous frame | current frame | next-frame overlap | zero pad]
    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
    x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)
    x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
    # kernels are time-flipped so the spectral product implements convolution, not correlation
    k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)

    # compute convolution
    X = torch.fft.rfft(x_padded, dim=-1)
    K = torch.fft.rfft(k_padded, dim=-1)

    out = torch.fft.irfft(X * K, dim=-1)

    # combine in channels
    out = torch.sum(out, dim=2)

    # apply the cross-fading
    out = window.reshape(1, 1, 1, -1)*out
    # overlap-add: current frame plus the faded-out tail of the previous frame
    crossfaded = out[:,:,:,frame_size:2*frame_size] + torch.cat([torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device), out[:, :, :-1, 2*frame_size:3*frame_size]], dim=-2)

    return crossfaded.reshape(nb_batches, nb_out_channels, -1)

View File

@@ -0,0 +1,8 @@
def _conv1d_flop_count(layer, rate):
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
def _dense_flop_count(layer, rate):
return 2 * ((layer.in_features + 1) * layer.out_features * rate )

View File

@@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
# (one open binary file handle plus stream metadata per logged key)
_state = dict()
# output directory for the .bin/.yml stream files; re-assigned by init()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
hidden_size = gru.hidden_size
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
# reset gate
start, stop = 0 * hidden_size, 1 * hidden_size
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# update gate
start, stop = 1 * hidden_size, 2 * hidden_size
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
# new gate
start, stop = 2 * hidden_size, 3 * hidden_size
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """
    global _folder
    _folder = folder

    # reuse an existing folder with a warning, otherwise create it
    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On the first call for a key, opens `<_folder>/<key>.bin` for the raw
    samples and writes a `<key>.yml` sidecar with sample rate, shape and
    dtype. Later calls must match that metadata exactly.
    """
    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    if not key in _state:
        # first write: open the binary stream and record the metadata
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }
        # sidecar yaml describing the stream so read_data can decode it later
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # subsequent writes must be consistent with the recorded metadata
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up: close all binary files opened by write_data

    NOTE: `folder` is accepted for symmetry with init() but is not used here.
    """
    for entry in _state.values():
        entry['fid'].close()
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Returns {key : {'fs' : ..., 'dim' : ..., 'dtype' : ..., 'data' : ndarray}},
    where 'data' has shape (num_writes,) + dim.
    """
    # every logged stream is identified by its yaml sidecar file
    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        # decode the raw binary stream with the recorded dtype, one entry per write
        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        value['data'] = data.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict
def get_best_reshape(shape, target_ratio=1):
    """Calculate the best 2d reshape of `shape` given the target ratio (rows/cols).

    Returns (1,) for a single pixel, otherwise (num_rows, num_columns) with
    num_rows * num_columns equal to the total pixel count.
    """
    # total number of elements
    pixel_count = shape[0]
    for extent in shape[1:]:
        pixel_count *= extent

    if pixel_count == 1:
        return (1,)

    # start from the column count matching the target aspect ratio and
    # decrease until it divides the pixel count evenly
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
def get_type_and_shape(shape):
    """Decide how a stream of the given shape should be displayed.

    Returns ('plot', (1,)) for scalar streams, otherwise ('image', shape2d)
    where shape2d is either the original 2d shape or a balanced 2d reshape.
    """
    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    if len(shape) > 1:
        pixel_count = 1
        for s in shape:
            pixel_count *= s
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already genuinely 2-dimensional.
    # BUG FIX: the original used `or`, which is true for every 2d shape with
    # pixel_count > 1, so degenerate shapes like (N, 1) or (1, N) were never
    # re-shaped and the fallthrough below was unreachable for 2d input.
    if len(shape) == 2:
        if (shape[0] != pixel_count) and (shape[1] != pixel_count):
            return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders the streams returned by read_data as an mp4 animation

    data : dict
        dictionary of streams as returned by read_data
    filename : str
        output path; '.mp4' is appended if missing
    start_index, stop_index : int
        first and last (exclusive) sample index at the fastest stream's rate;
        a negative stop_index counts from the end
    interval : int
        delay between animation frames (passed to ArtistAnimation)
    half_signal_window_length : int
        half width of the sliding window shown for 1d ('plot') streams
    """
    # determine plot setup
    num_keys = len(data.keys())

    num_rows = int((num_keys * 3/4) ** .5)
    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()
        # display as a line plot or an image depending on the stream's shape
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        # ratio of this stream's rate to the fastest stream's rate
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the animated range so the plot window never runs out of bounds
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # index of this stream's entry at its own (possibly lower) rate
            # NOTE(review): feature_index is computed but the lookups below use
            # `index` directly — correct only when all streams share fs_max; confirm.
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)

View File

@@ -0,0 +1,27 @@
import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F
class FIR(nn.Module):
    """Fixed least-squares FIR filter applied independently per channel (no padding)."""
    def __init__(self, numtaps, bands, desired, fs=2):
        """
        Parameters:
        -----------
        numtaps : int
            number of filter taps; must be odd (even values are incremented by one)
        bands, desired :
            band edges and desired gains as accepted by scipy.signal.firls
        fs : float
            sampling frequency passed to scipy.signal.firls, defaults to 2
        """
        super().__init__()
        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
        # BUG FIX: register the coefficients as a buffer instead of a plain
        # tensor attribute so they follow .to(device)/.to(dtype) and appear in
        # state_dict, while staying non-trainable. Access via self.weight is unchanged.
        self.register_buffer("weight", torch.from_numpy(a.astype(np.float32)))

    def forward(self, x):
        """Filter x of shape (batch, channels, samples); output length shrinks by numtaps - 1."""
        num_channels = x.size(1)

        # one copy of the filter per channel, applied via grouped convolution
        weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y

View File

@@ -0,0 +1,230 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.softquant import soft_quant
class LimitedAdaptiveComb1d(nn.Module):
    # class-level instance counter used for auto-generated layer names
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=[-6, 6],
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """
        Adaptive comb filter with gain limiting; kernels and gains are
        predicted frame-wise from a feature vector.

        Parameters:
        -----------
        kernel_size : int
            length of the adaptive comb-filter kernel

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int, optional
            frame size, defaults to 160

        overlap_size : int, optional
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40

        padding : List[int, int], optional
            left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]

        max_lag : int, optional
            maximal pitch lag, defaults to 256

        gain_limit_db : float, optional
            upper limit for the per-frame comb gain in dB, defaults to 10

        global_gain_limits_db : List[float, float], optional
            lower and upper limit for the global filter gain in dB, defaults to [-6, 6]
            (NOTE(review): mutable default list; safe only because it is never mutated)

        norm_p : int, optional
            p-norm used for kernel normalization, defaults to 2

        name: str or None, optional
            specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
        """

        super(LimitedAdaptiveComb1d, self).__init__()

        # comb filtering operates on a single channel
        self.in_channels = 1
        self.out_channels = 1
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.max_lag = max_lag
        self.limit_db = gain_limit_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        # optional weight norm; otherwise `norm` is the identity
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))

        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # comb filter gain
        self.filter_gain = norm(nn.Linear(feature_dim, 1))
        # 0.11512... = ln(10)/20: converts dB to the natural-log domain
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        # global gain, bounded via tanh to [log_min, log_max] in forward()
        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # falling raised-cosine half window used for cross-fading between frames
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive 1d comb filtering

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags: torch.LongTensor
            frame-wise lags for comb-filtering
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1 rises (fade in new frame), win2 falls (fade out previous frame)
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        # normalize kernels to unit p-norm; epsilon guards against division by zero
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

        # per-frame comb gain, upper-bounded by exp(log_gain_limit)
        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        # calculate gains
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        # optionally dump internals for offline inspection via utils.endoscopy
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        # per-sample gather indices for one frame (plus kernel and overlap tails)
        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)

        for i in range(num_frames):
            # shift the gather indices back by this frame's pitch lag
            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
            # batch handled via grouped convolution
            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            offset = self.max_lag + self.padding[0]
            # comb structure: scaled filtered lagged signal plus the direct signal
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        """ Approximate floating-point operations per second at sample rate `rate`. """
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # per-frame gain computation and application
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # a0 computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count

View File

@@ -0,0 +1,200 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.ada_conv import adaconv_kernel
from utils.softquant import soft_quant
class LimitedAdaptiveConv1d(nn.Module):
    # class-level instance counter used for auto-generated layer names
    COUNTER = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 name=None,
                 gain_limits_db=[-6, 6],
                 shape_gain_db=0,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """
        Adaptive 1d convolution with gain limiting; kernels and gains are
        predicted frame-wise from a feature vector.

        Parameters:
        -----------
        in_channels : int
            number of input channels

        out_channels : int
            number of output channels

        kernel_size : int
            length of the adaptive filter kernels

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int
            frame size

        overlap_size : int
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame

        padding : List[int, int]
            left and right padding; defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        gain_limits_db : List[float, float]
            lower and upper limit for the filter gain in dB
            (NOTE(review): mutable default list; safe only because it is never mutated)

        shape_gain_db : float
            blend between adaptive kernel and identity kernel; 0 dB keeps the
            adaptive kernel unchanged, negative values pull it towards identity

        norm_p : int
            p-norm used for kernel normalization
        """

        super(LimitedAdaptiveConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.gain_limits_db = gain_limits_db
        self.shape_gain_db = shape_gain_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
            LimitedAdaptiveConv1d.COUNTER += 1
        else:
            self.name = name

        # optional weight norm; otherwise `norm` is the identity
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))

        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # linear-domain blend factor between adaptive and identity kernels
        self.shape_gain = min(1, 10**(shape_gain_db / 20))

        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
        # 0.11512... = ln(10)/20: converts dB limits to the natural-log domain
        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # falling raised-cosine half window used for cross-fading between frames
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def flop_count(self, rate):
        """ Approximate floating-point operations per second at sample rate `rate`. """
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # gain computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += 3 * overlap * frame_rate * self.out_channels

        return count

    def forward(self, x, features, debug=False):
        """ adaptive 1d convolution

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1 rises (fade in new frame), win2 falls (fade out previous frame)
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))

        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))

        # limit shape: blend with an identity (delta) kernel when shape_gain < 1
        id_kernels = torch.zeros_like(conv_kernels)
        id_kernels[..., self.padding[1]] = 1

        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels

        # calculate gains, bounded via tanh to [log_min, log_max] in the log domain
        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)

        # optionally dump internals for offline inspection via utils.endoscopy
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # fold gains into the kernels and reorder to (batch, out, in, frames, taps)
        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)

        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)

        # FFT-based frame-wise convolution with cross-fade (win1 is the rising half window)
        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)

        return output

View File

@@ -0,0 +1,100 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class NoiseShaper(nn.Module):
def __init__(self,
feature_dim,
frame_size=160
):
"""
Parameters:
-----------
feature_dim : int
dimension of input features
frame_size : int
frame size
"""
super().__init__()
self.feature_dim = feature_dim
self.frame_size = frame_size
# feature transform
self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
def flop_count(self, rate):
frame_rate = rate / self.frame_size
shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
return shape_flops
def forward(self, features):
""" creates temporally shaped noise
Parameters:
-----------
features : torch.tensor
frame-wise features of shape (batch_size, num_frames, feature_dim)
"""
batch_size = features.size(0)
num_frames = features.size(1)
frame_size = self.frame_size
num_samples = num_frames * frame_size
# feature path
f = F.pad(features.permute(0, 2, 1), [1, 0])
alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
alpha = alpha.permute(0, 2, 1)
# signal generation
y = torch.randn((batch_size, num_frames, frame_size), dtype=features.dtype, device=features.device)
y = alpha * y
return y.reshape(batch_size, 1, num_samples)

View File

@@ -0,0 +1,84 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class PitchAutoCorrelator(nn.Module):
    """Computes frame-wise normalized autocorrelation of a signal around
    given per-frame pitch lags, over a small window of lag offsets."""

    def __init__(self,
                 frame_size=80,
                 pitch_min=32,
                 pitch_max=300,
                 radius=2):
        """
        Parameters:
        -----------
        frame_size : int
            number of samples per analysis frame
        pitch_min, pitch_max : int
            expected lag range; pitch_max determines the left zero-padding
        radius : int
            lag offsets -radius..radius are evaluated around each given lag
        """
        super().__init__()

        self.frame_size = frame_size
        self.pitch_min = pitch_min
        self.pitch_max = pitch_max
        self.radius = radius

    def forward(self, x, periods):
        """Returns acorr of shape (batch_size, channels, num_frames, 2 * radius + 1).

        x : (batch_size, channels, num_samples); periods : (batch_size, num_frames)
        """
        num_frames = periods.size(1)
        batch_size = periods.size(0)
        num_samples = self.frame_size * num_frames
        channels = x.size(1)

        assert num_samples == x.size(-1)

        # IDIOM FIX: renamed from `range`, which shadowed the builtin
        offsets = torch.arange(-self.radius, self.radius + 1, device=x.device)

        # per-sample lagged lookup indices: sample index shifted left by the
        # frame's pitch lag (pitch_max compensates the left padding below)
        idx = torch.arange(self.frame_size * num_frames, device=x.device)
        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
        lookup = idx + self.pitch_max - p_up
        lookup = lookup.unsqueeze(-1) + offsets
        lookup = lookup.unsqueeze(1)

        # padding so lagged reads before the first sample hit zeros
        x_pad = F.pad(x, [self.pitch_max, 0])
        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)

        # framing
        x_select = torch.gather(x_ext, 2, lookup)
        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)

        # calculate normalized auto-correlation (epsilon guards zero-energy frames)
        dotp = torch.sum(x_frames * lag_frames, dim=-2)
        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
        lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)

        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)

        return acorr

View File

@@ -0,0 +1,167 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# fractional-delay interpolation FIR bank: 12 phases x 8 taps, stored as
# Q15 integers and normalized by 2**15 below.
# NOTE(review): presumably the SILK resampler's fractional interpolation
# table (phase k at offset ~(2k+1)/24) — confirm against the SILK resampler tables.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
    ],
    dtype=np.float32
) / 2**15

# Q16-scaled coefficients of the even/odd branches of the HQ 2x upsampler
# (fed to get_impz below to derive FIR approximations of its impulse response)
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
def get_impz(coeffs, n):
    """Impulse response (length n) of the 3-stage recursive filter defined by
    `coeffs` (the third stage uses gain 1 + coeffs[2]).

    Feeds a unit impulse through the cascade and records the output.
    """
    state = [0, 0, 0]
    response = np.zeros(n)
    gains = (coeffs[0], coeffs[1], 1 + coeffs[2])

    sample = 1  # unit impulse
    for i in range(n):
        acc = sample
        # run the sample through the three first-order sections in turn
        for stage in range(3):
            diff = acc - state[stage]
            scaled = diff * gains[stage]
            out = state[stage] + scaled
            state[stage] = acc + scaled
            acc = out
        response[i] = acc
        sample = 0

    return response
class SilkUpsampler(nn.Module):
    """Upsamples 16 kHz audio to 24 or 48 kHz.

    A 128-tap FIR approximation of SILK's high-quality allpass 2x upsampler
    takes the signal to 32 kHz; a fractional 3/2 interpolator takes it to
    48 kHz. For 24 kHz output every second 48 kHz sample is kept.
    """
    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}

    def __init__(self,
                 fs_in=16000,
                 fs_out=48000):
        """
        Parameters
        ----------
        fs_in : int
            input sampling rate; must be 16000
        fs_out : int
            output sampling rate; 24000 or 48000

        Raises
        ------
        ValueError
            if either rate is unsupported
        """
        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')
        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')

        # hq 2x upsampler as FIR approximation: length-128 impulse responses
        # of the even/odd allpass phases, time-reversed so F.conv1d (which
        # cross-correlates) applies them as convolutions
        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        hq_2x_up_odd = get_impz(hq_2x_up_c_odd , 128)[::-1].copy()
        # non-trainable parameters so the filters follow .to(device)/.float()
        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(hq_2x_up_odd ).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_padding = [127, 0]  # causal (left-only) padding for the 128 taps

        # interpolation filters: three rows of the 12-phase fractional FIR
        # table (names suggest offsets 1/24, 17/24 and 9/24 — see frac_fir)
        frac_01_24 = frac_fir[0]
        frac_17_24 = frac_fir[8]
        frac_09_24 = frac_fir[4]

        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)

        # 48 kHz keeps every interpolated sample, 24 kHz every second one
        self.stride = 1 if fs_out == 48000 else 2

    def hq_2x_up(self, x):
        """2x upsampling of x (batch, channels, samples) -> (batch, channels, 2*samples)."""
        num_channels = x.size(1)

        # replicate the single filter per channel -> depthwise (grouped) conv
        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
        weight_odd = torch.repeat_interleave(self.hq_2x_up_odd , num_channels, 0)

        x_pad = F.pad(x, self.hq_2x_up_padding)

        y_even = F.conv1d(x_pad, weight_even, groups=num_channels)
        y_odd = F.conv1d(x_pad, weight_odd , groups=num_channels)

        # interleave the even/odd phases into a double-rate signal
        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)

        return y

    def interpolate_3_2(self, x):
        """3/2 interpolation: three output phases per two input samples."""
        num_channels = x.size(1)

        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)

        x_pad = F.pad(x, [8, 0])  # history for the 8-tap fractional filters
        y_01_24 = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
        y_17_24 = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)
        # third phase is computed on the one-sample-advanced signal
        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)

        # interleave the three phases; the final 3 samples depend on the
        # sample wrapped around by torch.roll, so they are dropped
        y = torch.cat(
            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
            dim=-1).flatten(2)

        return y[..., :-3]

    def forward(self, x):
        """Upsample x of shape (batch, channels, samples) to fs_out."""
        y_2x = self.hq_2x_up(x)
        y_3x = self.interpolate_3_2(y_2x)

        # stride selects 48 kHz (1) or 24 kHz (2) output
        return y_3x[:, :, ::self.stride]

View File

@@ -0,0 +1,145 @@
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
from utils.softquant import soft_quant
class TDShaper(nn.Module):
    """Temporal-envelope shaper.

    Predicts a positive per-sample gain from frame-wise features and the
    signal's own pooled log envelope and applies it to the signal;
    optionally adds a generated "innovation" component.
    """
    COUNTER = 1  # class-level counter; not modified anywhere in this class

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
    ):
        """
        Parameters:
        -----------
        feature_dim : int
            dimension of input features

        frame_size : int
            frame size; must be divisible by avg_pool_k

        avg_pool_k : int, optional
            kernel size and stride for avg pooling of the envelope

        innovate : bool, optional
            if True, adds a generated innovation signal to the output

        pool_after : bool, optional
            if True, the envelope log is taken before pooling instead of after

        softquant : bool, optional
            if True, applies soft quantization noise to feature_alpha1_f

        apply_weight_norm : bool, optional
            if True, wraps the conv layers in weight normalization
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after
        assert frame_size % avg_pool_k == 0
        # pooled envelope values per frame plus one slot for the frame mean
        self.env_dim = frame_size // avg_pool_k + 1

        # identity wrapper when weight norm is disabled
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform (kernel size 2 -> causal with 1 left pad in forward)
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            # NOTE(review): only the feature branch is soft-quantized;
            # feature_alpha1_t and feature_alpha2 are left as-is — confirm intended
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            # NOTE(review): these expect feature_dim + env_dim input channels,
            # but forward() feeds them `f` (feature_dim channels only) — the
            # innovate path looks inconsistent; confirm whether
            # torch.cat((f, t), dim=1) was intended.
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))

    def flop_count(self, rate):
        """Approximate multiply-accumulate count per second at sample rate *rate*."""
        frame_rate = rate / self.frame_size
        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """Pooled log envelope of x (batch, 1, samples): per-frame values are
        mean-removed, with the mean appended as an extra (env_dim-th) entry."""
        x = torch.abs(x)
        if self.pool_after:
            x = torch.log(x + .5**16)  # .5**16 avoids log(0)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # regroup pooled values into frames of env_dim - 1 entries
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        debug : bool
            unused

        Returns:
        --------
        torch.tensor of shape (batch_size, 1, num_samples)
        """
        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path: left-pad by 1 so the kernel-2 convs stay causal
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        # exp guarantees strictly positive gains
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            # see NOTE(review) in __init__ about the input channels here
            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path: one gain vector (frame_size entries) per frame
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)

View File

@@ -0,0 +1,112 @@
import os
import torch
import numpy as np
def load_lpcnet_features(feature_file, version=2):
    """Read a raw float32 LPCNet feature dump and split it into model inputs.

    Returns a dict with:
        'features' : cepstrum + pitch correlation, shape (frames, 19)
        'periods'  : integer pitch periods, shape (frames, 1)
        'lpcs'     : LPC coefficients, shape (frames, 16)
    """
    # frame length and column layout per feature-dump version
    layouts = {
        2: (36, {'cepstrum': [0, 18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]}),
        1: (55, {'cepstrum': [0, 18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]}),
    }
    if version not in layouts:
        raise ValueError(f'unknown feature version: {version}')
    frame_length, layout = layouts[version]

    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32')).reshape((-1, frame_length))

    cep_lo, cep_hi = layout['cepstrum']
    pc_lo, pc_hi = layout['pitch_corr']
    features = torch.cat((raw_features[:, cep_lo:cep_hi], raw_features[:, pc_lo:pc_hi]), dim=1)

    lpc_lo, lpc_hi = layout['lpc']
    lpcs = raw_features[:, lpc_lo:lpc_hi]

    # undo the dump's period normalization and truncate to integer lags
    per_lo, per_hi = layout['periods']
    periods = (0.1 + 50 * raw_features[:, per_lo:per_hi] + 100).long()

    return {'features': features, 'periods': periods, 'lpcs': lpcs}
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize a 16-bit mono signal and write it, duplicated into both
    interleaved slots, into a new file shaped like the reference data.

    Side effects: writes '<signal>_preemph.raw' next to the input and writes
    *new_data_path* (same length as the reference file, zero-filled elsewhere).

    offset : samples dropped from the start of the pre-emphasized signal —
        presumably algorithmic-delay compensation; confirm against the pipeline.
    """
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    assert len(signal) % 160 == 0
    num_frames = len(signal) // 160  # NOTE(review): unused — the loop recomputes it
    mem = np.zeros(1)  # one-sample filter history carried across frames
    for fr in range(len(signal)//160):
        # y[n] = x[n] - preemph_factor * x[n-1], applied frame by frame
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
    new_data[:] = 0
    N = len(signal) - offset
    # write the delayed signal into both interleaved positions (offset 1 and 2)
    new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
def parse_warpq_scores(output_file):
    """Collect all WARP-Q scores reported in *output_file*, in order."""
    prefix = "WARP-Q score:"
    with open(output_file, "r") as f:
        return [float(line.split(prefix)[-1]) for line in f if line.startswith(prefix)]
def parse_stats_file(file):
    """Parse the first three 'label: value' lines of a stats file.

    Returns (mean, bt_mean, top_mean) as floats.
    """
    with open(file, "r") as f:
        values = [float(f.readline().split(":")[-1]) for _ in range(3)]
    return values[0], values[1], values[2]
def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()
    for file in os.listdir(test_folder):
        if not file.startswith('stats_'):
            continue
        metric = file[len("stats_") : -len(".txt")]
        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")
        # [mean, best-tier mean, top mean] as produced by parse_stats_file
        results[metric] = list(parse_stats_file(os.path.join(test_folder, file)))

    return results

View File

@@ -0,0 +1,95 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
    """Return the total number of parameter elements in *model*.

    Replaces the original ``torch.ones_like(p).sum().item()``, which
    allocated a full tensor per parameter and returned a float, with
    ``p.numel()``, which is exact, allocation-free and returns an int.

    Parameters
    ----------
    model : torch.nn.Module
    verbose : bool
        if True, print a per-parameter breakdown
    """
    total = 0
    for name, p in model.named_parameters():
        count = p.numel()
        if verbose:
            print(f"{name}: {count} parameters")
        total += count
    return total
def count_nonzero_parameters(model, verbose=False):
    """Return the number of non-zero parameter entries in *model*,
    optionally printing a per-parameter breakdown."""
    total = 0
    for name, param in model.named_parameters():
        nnz = torch.count_nonzero(param).item()
        if verbose:
            print(f"{name}: {nnz} non-zero parameters")
        total = total + nnz
    return total
def retain_grads(module):
    """Request gradient retention on every trainable parameter of *module*,
    so .grad stays populated for inspection after backward()."""
    trainable = (p for p in module.parameters() if p.requires_grad)
    for param in trainable:
        param.retain_grad()
def get_grad_norm(module, p=2):
    """Return the p-norm of all gradients of *module*'s trainable parameters.

    Fix: parameters whose gradient has not been populated yet
    (``param.grad is None``) are skipped instead of raising a TypeError,
    so the function is safe to call before any backward() pass
    (it then returns 0.0).
    """
    norm = 0
    for param in module.parameters():
        if param.requires_grad and param.grad is not None:
            norm = norm + (torch.abs(param.grad) ** p).sum()
    return norm ** (1 / p)
def create_weights(s_real, s_gen, alpha):
    """Per-discriminator weights exp(alpha * (real_score - gen_score)),
    computed without tracking gradients."""
    with torch.no_grad():
        return [torch.exp(alpha * (sr[-1] - sg[-1])) for sr, sg in zip(s_real, s_gen)]
def _get_candidates(module: torch.nn.Module):
    """Return attribute names in *module*'s __dict__ that have a
    weight-norm style '<name>_v' companion attribute."""
    return [key for key in module.__dict__.keys() if hasattr(module, key + '_v')]
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
    """Strip weight normalization from every submodule of *model*.

    Fix: the original used a bare ``except: pass``, which also swallowed
    KeyboardInterrupt and unrelated programming errors. ``remove_weight_norm``
    raises ValueError when the named attribute is not actually weight-normed,
    so only that is caught now.
    """
    for name, m in model.named_modules():
        for candidate in _get_candidates(m):
            try:
                remove_weight_norm(m, name=candidate)
                if verbose: print(f'removed weight norm on weight {name}.{candidate}')
            except ValueError:
                # candidate had a '*_v' companion but no weight-norm hook
                pass

View File

@@ -0,0 +1,153 @@
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
    """Crude energy-based voice-activity detector.

    Returns the signal (truncated to whole frames) and a per-sample soft
    mask in [0, 1] that is ~1 where the smoothed frame energy exceeds
    `stop_db` relative to the loudest frame.
    """
    frame_length = (fs + 49) // 50   # 20 ms frames, rounded up
    x = x[: frame_length * (len(x) // frame_length)]  # drop trailing partial frame

    frames = x.reshape(-1, frame_length)

    frame_energy = np.sum(frames ** 2, axis=1)
    # 5-frame moving average to avoid chattering on/off decisions
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    # NOTE(review): frame_energy is a power quantity, so a dB threshold would
    # conventionally use 10**(stop_db/10); /20 halves the effective dB range —
    # confirm this is intentional.
    max_threshold = frame_energy.max() * 10 ** (stop_db/20)

    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # smooth the binary per-sample mask with a normalized half-sine window
    filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    filter = filter / filter.sum()

    mask = np.convolve(vactive, filter, mode='same')

    return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Reduce a per-sample mask to one mean value per (overlapping) analysis
    frame: frame i covers samples [i*hop_size, i*hop_size + frame_size)."""
    num_samples = frame_size + (num_frames - 1) * hop_size

    # zero-pad or truncate so exactly num_samples samples are covered
    padded = np.zeros(num_samples, dtype=mask.dtype)
    n = min(len(mask), num_samples)
    padded[:n] = mask[:n]

    starts = [i * hop_size for i in range(num_frames)]
    return np.array([np.mean(padded[s : s + frame_size]) for s in starts])
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Power spectra (positive-frequency bins only) of overlapping windowed
    frames of *x*; returns shape (num_spectra, window_size // 2 + 1)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    num_bins = window_size // 2 + 1

    frames = np.stack([x[i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * win
    spec = np.fft.fft(frames, axis=1)[:, :num_bins]

    return np.abs(spec) ** 2
def frequency_mask(num_bands, up_factor, down_factor):
    """Spectral spreading matrix combining upward masking (factor decaying
    per band towards lower bands) and downward masking; entry (i, j) gives
    the contribution of band j to the masking threshold of band i."""
    idx = np.arange(num_bands)
    delta = idx[:, None] - idx[None, :]  # row band minus column band

    up_mask = np.where(delta >= 0, up_factor ** np.maximum(delta, 0), 0.0)
    down_mask = np.where(delta <= 0, down_factor ** np.maximum(-delta, 0), 0.0)

    return down_mask @ up_mask
def rect_fb(band_limits, num_bins=None):
    """Rectangular (non-overlapping, 0/1) filter bank whose band b covers
    bins [band_limits[b], band_limits[b+1])."""
    if num_bins is None:
        num_bins = band_limits[-1]
    num_bands = len(band_limits) - 1

    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames detected as silence in the
            reference are excluded from the final error average

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim samples to same size
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15   # rescale to 16-bit PCM range
    y = y[:num_samples] * 2**15

    # constant offset acts as an absolute noise floor
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking (threshold derived from the reference only)
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: one-pole smoothing across frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]

    # apply mask to both signals so the comparison is mask-relative
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, band-averaged then
    # frame-averaged
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic mean over active frames, compressed back with the 1/6 power
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # Both inputs must be sampled at 16 kHz. The original test,
    # max(fs1, fs2) != 16000, accepted a lower-rate file paired with a
    # 16 kHz one; also fixes the 'diffrent' typo in the message.
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # scale 16-bit PCM to [-1, 1] as expected by compare()
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")

View File

@@ -0,0 +1,122 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
def hangover(lags, num_frames=10):
    """Bridge short unvoiced gaps in a pitch-lag track.

    Zero lags are replaced by the most recent non-zero lag, but for at most
    *num_frames* consecutive frames; longer gaps stay zero after that.
    The input array is not modified; a patched copy is returned.
    """
    lags = lags.copy()
    zeros_filled = 0
    prev_lag = 0
    for i, lag in enumerate(lags):
        if lag != 0:
            prev_lag = lag
            zeros_filled = 0
        elif zeros_filled < num_frames:
            lags[i] = prev_lag
            zeros_filled += 1
    return lags
def smooth_pitch_lags(lags, d=2):
    """Smooth pitch lags with a symmetric triangular kernel of radius *d*,
    processed per SILK frame (4 lags each).

    All-zero (unvoiced) frames are left untouched; missing or unvoiced
    context at frame borders is replaced by mirrored in-frame values.
    Returns a rounded copy; the input is not modified.
    """
    assert d < 4

    num_silk_frames = len(lags) // 4

    smoothed_lags = lags.copy()
    # triangular kernel [1..d, d+1, d..1], normalized to unit sum
    tmp = np.arange(1, d+1)
    kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
    kernel = kernel / np.sum(kernel)

    last = lags[0:d][::-1]  # left context for the first frame: mirrored head
    for i in range(num_silk_frames):
        frame = lags[i * 4: (i+1) * 4]

        if np.max(np.abs(frame)) == 0:
            # unvoiced frame: skip smoothing, keep its tail as left context
            last = frame[4-d:]
            continue

        if i == num_silk_frames - 1:
            # no right neighbour: mirror this frame's tail
            next = frame[4-d:][::-1]
        else:
            next = lags[(i+1) * 4 : (i+1) * 4 + d]

        # unvoiced right/left context is replaced by mirrored in-frame values
        if np.max(np.abs(next)) == 0:
            next = frame[4-d:][::-1]

        if np.max(np.abs(last)) == 0:
            last = frame[0:d][::-1]

        smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')

        smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)

        last = frame[4-d:]  # un-smoothed tail becomes the next left context

    return smoothed_lags
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
    """Normalized autocorrelation of each frame around its pitch lag.

    Parameters
    ----------
    x : np.ndarray
        signal; len(x) must be a multiple of frame_size
    frame_size : int
        samples per frame
    lags : np.ndarray
        one pitch lag per frame
    history : np.ndarray or None
        past samples prepended to x; must hold at least max_lag + radius
        samples (zero history is used if None)
    max_lag : int
        maximum supported pitch lag
    radius : int
        correlations are evaluated at offsets lag + j - radius for
        j in [0, 2*radius]
    add_double_lag_acorr : bool
        if True, additionally evaluates around 2 * lag (for lags at or
        above no_pitch_threshold)
    no_pitch_threshold : int
        lags below this are treated as "no pitch" and are not doubled

    Returns
    -------
    (np.ndarray, np.ndarray)
        correlations of shape (num_frames, (1 or 2) * (2*radius + 1))
        and a copy of the lags
    """
    eps = 1e-9
    lag_multiplier = 2 if add_double_lag_acorr else 1
    if history is None:
        history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)

    offset = len(history)
    assert offset >= max_lag + radius
    assert len(x) % frame_size == 0

    num_frames = len(x) // frame_size
    lags = lags.copy()

    x_ext = np.concatenate((history, x), dtype=x.dtype)

    d = radius
    num_acorrs = 2 * d + 1
    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)

    for idx in range(num_frames):
        lag = lags[idx].item()
        frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]

        for k in range(lag_multiplier):
            # second pass (k == 1) looks at twice the lag, but only for
            # lags that represent actual pitch
            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
            for j in range(num_acorrs):
                past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
                # normalized cross-correlation between the frame and its lagged copy
                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)

    return acorrs, lags

View File

@@ -0,0 +1,144 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import numpy as np
import torch
import scipy
import scipy.signal
from utils.pitch import hangover, calculate_acorr_window
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
def spec_from_lpc(a, n_fft=128, eps=1e-9):
    """Power-spectrum envelope 1 / |A(f)|^2 from LPC coefficients.

    *a* has the LPC order on its last axis; the positive-frequency half
    (n_fft // 2 + 1 bins) is returned, eps-regularized to avoid division
    by zero at spectral nulls.
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # build the prediction polynomial A(z) = 1 - sum a_k z^-k, zero-padded
    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1 : 1 + order] = -a

    # |A| ** 2 on the unit circle, positive frequencies only
    response = np.abs(np.fft.fft(poly, axis=-1)[..., : n_fft // 2 + 1]) ** 2

    return 1 / (response + eps)
def silk_feature_factory(no_pitch_value=256,
                         acorr_radius=2,
                         pitch_hangover=8,
                         num_bands_clean_spec=64,
                         num_bands_noisy_spec=18,
                         noisy_spec_scale='opus',
                         noisy_apply_dct=True,
                         add_double_lag_acorr=False
                         ):
    """Build a closure that assembles SILK-derived model input features.

    The window and filter banks are created once here and captured by the
    returned ``create_features`` function.
    """
    w = scipy.signal.windows.cosine(320)  # 20 ms analysis window at 16 kHz
    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)

    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods):
        """Concatenate clean-spectrum, noisy-cepstrum, pitch-acorr, LTP and
        log-gain features; returns (features, int64 periods).

        NOTE(review): assumes `noisy` holds 80-sample subframes and
        `noisy_history` provides at least 160 past samples — confirm
        against the caller.
        """
        periods = periods.copy()

        if pitch_hangover > 0:
            periods = hangover(periods, num_frames=pitch_hangover)

        # unvoiced frames get a dedicated out-of-range marker
        periods[periods == 0] = no_pitch_value

        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)

        # cepstrum (or plain log spectrum) on 50% overlapping 320-sample
        # frames; repeated 2x along time to match the subframe rate
        if noisy_apply_dct:
            noisy_cepstrum = np.repeat(
                cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
        else:
            noisy_cepstrum = np.repeat(
                log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)

        log_gains = np.log(gains + 1e-9).reshape(-1, 1)

        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)

        features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)

        return features, periods.astype(np.int64)

    return create_features
def load_inference_data(path,
                        no_pitch_value=256,
                        skip=92,
                        preemph=0.85,
                        acorr_radius=2,
                        pitch_hangover=8,
                        num_bands_clean_spec=64,
                        num_bands_noisy_spec=18,
                        noisy_spec_scale='opus',
                        noisy_apply_dct=True,
                        add_double_lag_acorr=False,
                        **kwargs):
    """Load SILK dump files from *path* and derive model inputs.

    Returns (signal, features, periods, numbits) as torch tensors.

    Bug fix: the trimmed `num_bits_smooth` used to be assigned from
    `num_bits` (a copy-paste error), silently discarding the smoothed
    bitrate values; it is now sliced from itself.
    """
    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")

    lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
    ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
    gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)

    # load signal, add back delay and pre-emphasize
    signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)

    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_double_lag_acorr)

    # trim everything to a whole number of 20 ms frames (4 subframes each)
    num_frames = min((len(signal) // 320) * 4, len(lpcs))
    signal = signal[: num_frames * 80]
    lpcs = lpcs[: num_frames]
    ltps = ltps[: num_frames]
    gains = gains[: num_frames]
    periods = periods[: num_frames]
    num_bits = num_bits[: num_frames // 4]
    num_bits_smooth = num_bits_smooth[: num_frames // 4]  # was: num_bits[...]

    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)

    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods)

    if preemph > 0:
        signal[1:] -= preemph * signal[:-1]

    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)

View File

@@ -0,0 +1,110 @@
import torch
@torch.no_grad()
def compute_optimal_scale(weight):
    """Per-output-row quantization scale for a 2-D weight matrix.

    The scale for each row is the larger of max|w| / 127 and
    max|w[2k] + w[2k+1]| / 129 (the pairwise-sum constraint of the
    int8 kernels downstream).

    Fixes:
    - padding used ``pad = n_out - n_out % 8``, which does NOT round the
      row count up to a multiple of 8; it is now ``8 - n_out % 8``
      (the padded rows never affect the returned ``scale[:n_out]``, but
      the intent — a row count divisible by 8 — is now met);
    - the redundant inner ``with torch.no_grad():`` (already guaranteed
      by the decorator) is removed;
    - the zero padding is allocated on the weight's device.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # pad rows up to the next multiple of 8
        pad = 8 - n_out % 8
        weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)

    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    scale = torch.maximum(scale_max, scale_sum)

    return scale[:n_out]
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Uniform noise in [-0.5, 0.5) shaped like *weight*, scaled per output
    row by the optimal quantization scale for the module's layout."""
    if isinstance(module, torch.nn.Conv1d):
        # (out, in, k) -> rows of (in * k) values, matching the int8 layout
        flat = weight.permute(0, 2, 1).flatten(1)
        noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        ch_in, ch_out, k = weight.shape
        flat = weight.permute(2, 1, 0).reshape(k * ch_out, ch_in)
        noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = noise.reshape(k, ch_out, ch_in).permute(2, 1, 0)
    elif weight.dim() == 2:
        # dense layer: rows are already output rows
        noise = (torch.rand_like(weight) - 0.5) * compute_optimal_scale(weight).unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')
    return noise
class SoftQuant:
    """Forward-hook pair that perturbs selected weights with quantization
    noise during training and restores them after the forward pass.

    The pre-hook adds noise (either ``q_scaled_noise`` when ``scale`` is
    None, or uniform noise scaled by ``scale * max|w|``); the post-hook
    subtracts the exact same noise, so the stored weights are unchanged
    between forward passes. In eval mode both hooks are no-ops.

    Fix: a stray no-op expression statement (``module``) left in
    ``apply`` has been removed; behavior is otherwise unchanged.
    """
    name: str

    def __init__(self, names: str, scale: float) -> None:
        self.names = names                 # weight attribute names to perturb
        self.quantization_noise = None     # per-name noise, live only during forward
        self.scale = scale                 # None -> use q_scaled_noise

    def __call__(self, module, inputs, *args, before=True):
        """Add (before=True) or remove (before=False) the noise in place."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    def apply(module, names=['weight'], scale=None):
        """Attach a SoftQuant hook pair to *module* and return the hook object.

        Raises ValueError if *module* lacks one of the named attributes.
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                raise ValueError("")

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        # tag both hooks so remove_soft_quant can find and match them
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
def soft_quant(module, names=['weight'], scale=None):
    """Attach soft-quantization hooks to *module* and return it (fluent style)."""
    SoftQuant.apply(module, names, scale)
    return module
def remove_soft_quant(module, names=['weight']):
for k, hook in module._forward_pre_hooks.items():
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_pre_hooks[k]
for k, hook in module._forward_hooks.items():
if hasattr(hook, 'sqm'):
if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
del module._forward_hooks[k]
return module

View File

@@ -0,0 +1,210 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import numpy as np
import scipy
import scipy.fftpack
import torch
def erb(f):
    """Equivalent rectangular bandwidth at frequency f.

    NOTE(review): the classic Glasberg-Moore formula uses f in kHz
    (24.7 * (4.37 * f_kHz + 1)) — confirm the unit callers pass in.
    """
    return (4.37 * f + 1) * 24.7
def inv_erb(e):
    """Algebraic inverse of erb(): frequency whose ERB value is e."""
    scaled = e / 24.7
    return (scaled - 1) / 4.37
def bark(f):
    """Frequency-to-Bark transform, 6 * asinh(f / 600)."""
    return m.asinh(f / 600) * 6
def inv_bark(b):
    """Bark-to-frequency transform, inverse of bark(): 600 * sinh(b / 6)."""
    return m.sinh(b / 6) * 600
# Maps scale name -> [forward transform (frequency -> scale units),
# inverse transform]; consumed by create_filter_bank for non-'opus' scales.
scale_dict = {
    'bark': [bark, inv_bark],
    'erb': [erb, inv_erb]
}
def gen_filterbank(N, Fs=16000, keep_size=False):
    """ERB-spread filter bank as a torch tensor of shape (M, N+1).

    Each output row spreads the N+1 linear FFT-bin energies with a response
    that is parabolic (in dB) near the band center and linear beyond half
    an ERB; rows are normalized to unit sum.

    Parameters
    ----------
    N : int
        number of bands (output rows, unless keep_size)
    Fs : int
        sampling rate in Hz
    keep_size : bool
        if True, emit N+1 rows so input and output sizes match
    """
    in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
    M = N + 1 if keep_size else N
    out_freq = (np.arange(M, dtype='float32')/N*Fs/2)[:,None]
    #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    ERB_N = 24.7 + .108*in_freq
    delta = np.abs(in_freq-out_freq)/ERB_N  # distance in ERBs
    center = (delta<.5).astype('float32')
    # -12 dB/ERB^2 inside half an ERB, linear -12 dB/ERB tail outside
    R = -12*center*delta**2 + (1-center)*(3-12*delta)
    RE = 10.**(R/10.)
    norm = np.sum(RE, axis=1)
    RE = RE/norm[:, np.newaxis]  # unit-sum rows
    return torch.from_numpy(RE)
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
    """Triangular filter bank with band centers on a perceptual scale.

    Parameters
    ----------
    num_bands : int
        number of bands (forced to 18 when scale='opus')
    n_fft : int
        FFT size; the bank covers n_fft // 2 + 1 positive-frequency bins
    fs : int
        sampling rate in Hz
    scale : str
        'bark', 'erb' (see scale_dict) or 'opus' (fixed band layout)
    round_center_bins : bool
        round band centers to integer bin positions
    return_upper : bool
        mirror the bank so it also covers the negative-frequency bins
    normalize : bool
        normalize each row to unit sum

    Returns
    -------
    np.ndarray of shape (num_bands, num_bins)
    """
    f0 = 0
    num_bins = n_fft // 2 + 1
    f1 = fs / n_fft * (num_bins - 1)  # Nyquist frequency
    fstep = fs / n_fft                # bin spacing in Hz

    if scale == 'opus':
        # Opus band edges on a 5 ms grid, expressed in FFT bins
        bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
        fac = 1000 * n_fft / fs / 5
        if num_bands != 18:
            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
            num_bands = 18
        center_bins = np.array([fac * bin for bin in bins_5ms])
    else:
        to_scale, from_scale = scale_dict[scale]
        s0 = to_scale(f0)
        s1 = to_scale(f1)
        # NOTE(review): interior centers step by (s1 - s0) / num_bands while
        # the last center is pinned to f1, so the top band is wider than the
        # others — confirm this is intentional.
        center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
        center_bins = (center_freqs - f0) / fstep

    if round_center_bins:
        center_bins = np.round(center_bins)

    filter_bank = np.zeros((num_bands, num_bins))

    band = 0
    for bin in range(num_bins):
        # update band index
        if bin > center_bins[band + 1]:
            band += 1

        # calculate filter coefficients: linear interpolation between the
        # two neighbouring band centers (triangular responses)
        frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
        filter_bank[band][bin] = frac
        filter_bank[band + 1][bin] = 1 - frac

    if return_upper:
        # mirror onto the negative-frequency bins (DC/Nyquist not duplicated)
        extend = n_fft - num_bins
        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)

    if normalize:
        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)

    return filter_bank
def compressed_log_spec(pspec):
    """Log10 power spectrum with band-to-band dynamic limiting.

    Operates on a 1-D band vector (it is applied per-row via
    np.apply_along_axis in log_spectrum_from_lpc). Each band's log value is
    floored by a follower that decays 2.5 per band and by a running maximum.

    NOTE(review): ``max(log_max, ...)`` floors every band at the running
    maximum of earlier bands, which can only raise values; reference
    implementations typically subtract a headroom (e.g. ``log_max - 8``)
    here — confirm this transcription.
    """
    lpspec = np.zeros_like(pspec)
    num_bands = pspec.shape[-1]

    log_max = -2   # running maximum of the log spectrum
    follow = -2    # decaying follower of the previous bands
    for i in range(num_bands):
        tmp = np.log10(pspec[i] + 1e-9)  # 1e-9 avoids log(0)
        tmp = max(log_max, max(follow - 2.5, tmp))
        lpspec[i] = tmp
        log_max = max(log_max, tmp)
        follow = max(follow - 2.5, tmp)

    return lpspec
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
    """Log spectral envelope from SILK LPC coefficients.

    The envelope 1 / |A(f)|^power is optionally bandwidth-expanded
    (gamma < 1), reduced through a filter bank *fb*, and either
    log-compressed band-by-band (compress=True) or plainly logged.
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # bandwidth expansion: a_k <- a_k * gamma^k
    a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)

    # evaluate the prediction polynomial A(z) on the unit circle
    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1 : 1 + order] = -a
    response = np.abs(np.fft.fft(poly, axis=-1)[..., : n_fft // 2 + 1]) ** power
    envelope = 1 / (response + eps)

    banked = envelope if fb is None else np.matmul(envelope, fb.T)

    if compress:
        return np.apply_along_axis(compressed_log_spec, -1, banked)
    return np.log(banked + eps)
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
    """Calculate a cepstrum from SILK LPCs (DCT-II of the log spectrum)."""
    log_spec = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
    """Compute the log magnitude spectrum on 50%-overlapping frames.

    x          : 1-D signal; its length must be a multiple of frame_size/2
    frame_size : analysis frame length (must be even)
    fb         : optional filter bank applied to the magnitude spectrum
    window     : optional analysis window of length frame_size
    power      : exponent applied to the magnitude spectrum
    """
    assert (2 * len(x)) % frame_size == 0
    assert frame_size % 2 == 0
    hop = frame_size // 2
    num_even = len(x) // frame_size
    num_odd = (len(x) - hop) // frame_size
    num_bins = hop + 1
    # interleave frame-aligned and half-frame-shifted frames
    frames = np.empty((num_even + num_odd, frame_size), dtype=x.dtype)
    frames[::2, :] = x[:num_even * frame_size].reshape(-1, frame_size)
    frames[1::2, :] = x[hop:hop + num_odd * frame_size].reshape(-1, frame_size)
    if window is not None:
        frames *= window.reshape(1, -1)
    spec = np.abs(np.fft.fft(frames, n=frame_size, axis=-1))[:, :num_bins] ** power
    if fb is not None:
        spec = np.matmul(spec, fb.T)
    return np.log(spec + 1e-9)
def cepstrum(x, frame_size, fb=None, window=None):
    """Calculate a cepstrum on 50%-overlapping frames (DCT-II of the log spectrum)."""
    log_spec = log_spectrum(x, frame_size, fb, window)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')

View File

@@ -0,0 +1,347 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
# Placeholder; the real name->setup registry is assigned at the bottom of this module.
setup_dict = {}
# Training setup for the LACE coded-speech enhancement model.
lace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'lace',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 128,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        # BUG FIX: 'pitch_hangover' appeared twice (8, then 0); the later
        # duplicate silently won, so the effective value 0 is kept here.
        'pitch_hangover': 0,
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
# Training setup for the NoLACE coded-speech enhancement model (regression stage).
nolace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        # BUG FIX: 'pitch_hangover' appeared twice (8, then 0); the later
        # duplicate silently won, so the effective value 0 is kept here.
        'pitch_hangover': 0,
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
# Training setup for the NoLACE model, adversarial fine-tuning stage.
nolace_setup_adv = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            # sparsification disabled from the start in the adversarial stage
            'sparsification_schedule': [0, 0, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        # BUG FIX: 'pitch_hangover' appeared twice (8, then 0); the later
        # duplicate silently won, so the effective value 0 is kept here.
        'pitch_hangover': 0,
    },
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'training': {
        'adv_target': 'target_orig',
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 10,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 20,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
    }
}
# LaVoce vocoder, regression (non-adversarial) training setup.
lavoce_setup = {
    'dataset': '/local/datasets/lpcnet_large/training',
    'validation_dataset': '/local/datasets/lpcnet_large/validation',
    'data': {'frames_per_sample': 100, 'target': 'signal'},
    'model': {
        'name': 'lavoce',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True,
        },
    },
    'training': {
        'batch_size': 256,
        'epochs': 50,
        'lr': 5e-4,
        'lr_decay_factor': 2.5e-5,
        # only the spectral-domain losses carry non-zero weight
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
    },
}
# LaVoce vocoder, adversarial fine-tuning setup.
lavoce_setup_adv = {
    'dataset': '/local/datasets/lpcnet_large/training',
    'data': {'frames_per_sample': 100, 'target': 'signal'},
    'model': {
        'name': 'lavoce',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True,
        },
    },
    # multi-resolution frequency-domain discriminator
    'discriminator': {
        'name': 'fdmresdisc',
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [64, 128, 256, 512, 1024, 2048],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
    },
    'training': {
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'lr': 1e-4,
        'lr_decay_factor': 2.5e-9,
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
    },
}
# Registry mapping a setup name (as used on the command line) to its config.
setup_dict = dict(
    lace=lace_setup,
    nolace=nolace_setup,
    nolace_adv=nolace_setup_adv,
    lavoce=lavoce_setup,
    lavoce_adv=lavoce_setup_adv,
)