add some code
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jean-Marc Valin */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# x is (batch, nb_in_channels, nb_frames*frame_size)
# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """Frame-wise adaptive FIR filtering via FFT with cross-faded frame transitions.

    Each frame of x is filtered with its own per-frame kernel; neighbouring
    frames are blended over the first `half_window.size(-1)` samples of every
    frame using the supplied half window.

    Parameters
    ----------
    x : torch.Tensor
        input signal of shape (batch, nb_in_channels, nb_frames * frame_size)
    kernels : torch.Tensor
        per-frame filters of shape
        (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
    half_window : torch.Tensor
        rising half of the cross-fade window; its length defines the overlap size
    fft_size : int, optional
        FFT length for the fast convolution; must satisfy
        fft_size >= 2 * frame_size + overlap_size and fft_size >= kernel_size
        (otherwise the zero-pad sizes below become negative), defaults to 256

    Returns
    -------
    torch.Tensor
        filtered signal of shape (batch, nb_out_channels, nb_frames * frame_size)
    """
    device=x.device
    overlap_size=half_window.size(-1)
    nb_frames=kernels.size(3)
    nb_batches=kernels.size(0)
    nb_out_channels=kernels.size(1)
    nb_in_channels=kernels.size(2)
    kernel_size = kernels.size(-1)

    # split the signal into frames: (batch, 1, in_ch, nb_frames, frame_size);
    # the singleton dim broadcasts against the out-channel dim of the kernels
    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)

    # build window: [zeros, rising window, ones, falling window, zeros]
    # laid out to match the padded chunk [prev frame | current frame | lookahead | pad]
    window = torch.cat(
        [
            torch.zeros(frame_size, device=device),
            half_window,
            torch.ones(frame_size - overlap_size, device=device),
            1 - half_window,
            torch.zeros(fft_size - 2 * frame_size - overlap_size,device=device)
        ])

    # previous frame (zeros before the first) and the first overlap_size
    # samples of the next frame (zeros after the last)
    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
    x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)

    # per-frame chunk [prev | cur | lookahead] zero-padded to fft_size;
    # kernels are time-reversed before padding
    x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
    k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)

    # compute convolution (frequency-domain product; zero padding keeps the
    # circular convolution linear within the region kept below)
    X = torch.fft.rfft(x_padded, dim=-1)
    K = torch.fft.rfft(k_padded, dim=-1)

    out = torch.fft.irfft(X * K, dim=-1)

    # combine in channels
    out = torch.sum(out, dim=2)

    # apply the cross-fading
    out = window.reshape(1, 1, 1, -1)*out
    # overlap-add: each frame's windowed centre plus the previous frame's
    # windowed tail (frames shifted by one along the frame axis)
    crossfaded = out[:,:,:,frame_size:2*frame_size] + torch.cat([torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device), out[:, :, :-1, 2*frame_size:3*frame_size]], dim=-2)

    return crossfaded.reshape(nb_batches, nb_out_channels, -1)
|
||||
@@ -0,0 +1,8 @@
|
||||
|
||||
|
||||
def _conv1d_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
|
||||
|
||||
|
||||
def _dense_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_features + 1) * layer.out_features * rate )
|
||||
@@ -0,0 +1,205 @@
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
# module-level registry of open data streams, populated by write_data()
_state = dict()
# output directory for .bin/.yml stream files; overridden by init()
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
    """Compute the internal gate activations of a single-layer GRU cell.

    Re-implements the torch GRU gate equations from the layer-0 weights
    (weight_ih_l0 / weight_hh_l0 and both bias vectors) so that the normally
    hidden gate values can be inspected.

    Returns a dict with keys 'reset_gate', 'update_gate' and 'new_gate'.
    """
    h = gru.hidden_size

    direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
    recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())

    # the stacked pre-activations are laid out as [reset | update | new]
    d_r, d_z, d_n = (direct[k * h : (k + 1) * h] for k in range(3))
    r_r, r_z, r_n = (recurrent[k * h : (k + 1) * h] for k in range(3))
    bi_r, bi_z, bi_n = (gru.bias_ih_l0[k * h : (k + 1) * h] for k in range(3))
    bh_r, bh_z, bh_n = (gru.bias_hh_l0[k * h : (k + 1) * h] for k in range(3))

    reset_gate = torch.sigmoid(d_r + bi_r + r_r + bh_r)
    update_gate = torch.sigmoid(d_z + bi_z + r_z + bh_z)
    # the reset gate modulates only the recurrent contribution of the new gate
    new_gate = torch.tanh(d_n + bi_n + reset_gate * (r_n + bh_n))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """

    global _folder
    _folder = folder

    # a pre-existing folder is reused (with a warning), never wiped
    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
|
||||
|
||||
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On first use of a key this opens <_folder>/<key>.bin for writing and
    records the stream metadata (fs, dim, dtype), mirroring it to a sidecar
    <_folder>/<key>.yml for later decoding by read_data. Subsequent calls
    must pass data with identical fs, dtype and shape.
    """

    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    # first write for this key: register the stream and dump its metadata
    if not key in _state:
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
            }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # subsequent writes must be consistent with the first one
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    # raw bytes are appended; shape/dtype for decoding live in the .yml
    _state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
    """ clean up

    Closes all stream files opened by write_data.

    NOTE(review): `folder` is unused; it is kept for interface compatibility
    with init()/read_data().
    """
    # iterate values directly instead of indexing by key
    for entry in _state.values():
        entry['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Returns a dict mapping each stream key to its metadata dict
    ({'fs', 'dim', 'dtype', ...}) extended with a 'data' entry holding a
    numpy array of shape (num_writes,) + dim.
    """

    # every stream has a <key>.yml metadata file next to its <key>.bin payload
    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # leading -1 restores the per-call axis: one slice per write_data call
        value['data'] = data.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
    """Find a 2d reshape of `shape` close to the target rows/cols ratio.

    The column count starts at the ideal value for `target_ratio` and is
    walked down to the nearest divisor of the total element count.

    Returns (1,) for a single element, otherwise (num_rows, num_columns).
    """
    pixel_count = shape[0]
    for extent in shape[1:]:
        pixel_count *= extent

    if pixel_count == 1:
        return (1,)

    # ideal column count for the requested ratio, then snap to a divisor
    cols = int((pixel_count / target_ratio) ** .5)
    while pixel_count % cols:
        cols -= 1

    return (pixel_count // cols, cols)
|
||||
|
||||
def get_type_and_shape(shape):
    """Decide how a stream of dimension `shape` should be displayed.

    Returns ('plot', (1,)) for scalar streams and ('image', some_2d_shape)
    for everything else, using get_best_reshape for non-2d data.
    """
    # can happen if data is one dimensional
    if not len(shape):
        shape = (1,)

    # total number of elements per frame
    pixel_count = 1
    for extent in shape:
        pixel_count *= extent

    if pixel_count == 1:
        return 'plot', (1, )

    # already 2-dimensional: keep the shape as-is
    # NOTE(review): for any 2d shape with more than one element this guard is
    # always true, so the get_best_reshape fall-through below is only reached
    # for non-2d input — confirm whether `and` was intended
    if len(shape) == 2:
        if (shape[0] != pixel_count) or (shape[1] != pixel_count):
            return 'image', shape

    return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render the streams returned by read_data into an mp4 animation.

    Scalar ('plot') streams are drawn as a sliding window of
    2 * half_signal_window_length samples around the current index; all other
    ('image') streams are shown frame by frame as images. One subplot per key.

    Parameters
    ----------
    data : dict
        output of read_data: {key: {'data', 'fs', 'dim', ...}}
    filename : str
        output path; '.mp4' is appended if missing
    start_index, stop_index : int
        sample range to animate (negative stop_index counts from the end);
        both are clamped so the signal window stays in bounds
    interval : int
        delay between animation frames in ms
    half_signal_window_length : int
        half-width of the sliding window used for 'plot' streams
    """

    # determine plot setup
    num_keys = len(data.keys())

    # near-square grid of subplots (rows/cols ratio ~ 3/4)
    num_rows = int((num_keys * 3/4) ** .5)

    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    # the fastest stream defines the master sample index
    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()

        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        # slower streams advance fractionally per master sample
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the animated range so the plot window never leaves the data
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but never used below;
            # the 'plot'/'image' branches index with the master `index` —
            # possibly a bug for streams with down_factor != 1, confirm
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
|
||||
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FIR(nn.Module):
    """Fixed least-squares FIR filter applied channel-wise (no padding).

    The kernel is designed once with scipy.signal.firls and applied as a
    depthwise (grouped) 1d convolution: every channel is filtered with the
    same taps. The output is numtaps - 1 samples shorter than the input.
    """
    def __init__(self, numtaps, bands, desired, fs=2):
        """
        Parameters:
        -----------
        numtaps : int
            number of filter taps; must be odd (even values are bumped up by one)

        bands, desired, fs :
            band edges, desired band gains and sampling frequency, as
            understood by scipy.signal.firls
        """
        super().__init__()

        # firls requires an odd tap count
        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # register as a (non-persistent) buffer so the kernel follows the
        # module through .to(device)/.float(); a plain tensor attribute would
        # be left behind. persistent=False keeps the state_dict unchanged for
        # existing checkpoints.
        self.register_buffer('weight', torch.from_numpy(a.astype(np.float32)), persistent=False)

    def forward(self, x):
        """Filter x of shape (batch, channels, samples), channel by channel."""
        num_channels = x.size(1)

        # one copy of the kernel per channel; grouped conv = depthwise filtering
        weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y
|
||||
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveComb1d(nn.Module):
    """Frame-adaptive, gain-limited single-channel comb filter.

    Per frame, a short FIR kernel and two gains are predicted from features.
    The kernel is applied at a pitch-lag offset into the past, the result is
    scaled and added to the input signal, and adjacent frames are cross-faded
    over the first `overlap_size` samples of every frame.
    """

    # instance counter used to auto-generate unique module names
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=[-6, 6],
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """
        Parameters:
        -----------

        kernel_size : int
            length of the adaptive comb-filter kernel

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int, optional
            frame size, defaults to 160

        overlap_size : int, optional
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40

        padding : List[int, int], optional
            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        max_lag : int, optional
            maximal pitch lag, defaults to 256

        name: str or None, optional
            specifies a name attribute for the module. If None the name is auto generated as limited_adaptive_comb1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d

        gain_limit_db : float, optional
            upper limit of the comb-filter gain in dB, defaults to 10

        global_gain_limits_db : List[float, float], optional
            lower and upper limit of the global output gain in dB, defaults to [-6, 6]

        norm_p : int, optional
            order of the norm used to normalize predicted kernels, defaults to 2

        softquant : bool, optional
            if true, wraps the kernel-predicting layer with soft_quant, defaults to False

        apply_weight_norm : bool, optional
            if true, applies weight norm to the predicting linear layers, defaults to False
        """

        super(LimitedAdaptiveComb1d, self).__init__()

        # comb filtering is single channel in, single channel out
        self.in_channels = 1
        self.out_channels = 1
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.max_lag = max_lag
        self.limit_db = gain_limit_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        # identity pass-through when weight norm is disabled
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))

        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # comb filter gain
        self.filter_gain = norm(nn.Linear(feature_dim, 1))
        # 0.11512925464970229 == log(10) / 20, i.e. dB -> natural-log factor
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            # bias init drives forward's exp(-relu(gain) + limit) to ~exp(-4),
            # i.e. the comb path starts nearly disabled
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        # global output gain, limited to [log_min, log_max] via tanh in forward
        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        # default: roughly centered kernel
        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # raised-cosine fade-out window of length overlap_size (fixed, not trained);
        # its flip (used as win1 in forward) is the matching fade-in
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive 1d convolution


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags: torch.LongTensor
            frame-wise lags for comb-filtering

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1 fades the new frame in, win2 fades the previous frame's tail out
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        # one kernel per frame, normalized to unit p-norm (eps avoids div by 0)
        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

        # comb gain in (0, exp(log_gain_limit)]: relu keeps the exponent below the limit
        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        # calculate gains (global gain limited to [log_min, log_max] via tanh)
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        # dump per-frame diagnostics via endoscopy when requested
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)


        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
        # left pad gives max_lag samples of history for the lag offset,
        # right pad gives overlap_size samples of lookahead
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        # index template for gathering one lag-shifted window per frame
        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)


        for i in range(num_frames):

            # window starting lags[..., i] samples behind the current frame
            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))

            # batch-of-one grouped conv applies each sample's own kernel
            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            # comb: scaled filtered history added to the (unshifted) input frame
            offset = self.max_lag + self.padding[0]
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        """Estimated complexity of forward() in FLOPs/s at sample rate `rate`."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        # each frame recomputes overlap_size extra samples for the cross-fade
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # a0 computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from utils.ada_conv import adaconv_kernel
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveConv1d(nn.Module):
    """Frame-adaptive 1d convolution with shape- and gain-limited kernels.

    Per frame, an (out_channels x in_channels x kernel_size) kernel and one
    gain per output channel are predicted from features; the filtering itself
    is done by adaconv_kernel with cross-faded frame transitions.
    """

    # instance counter used to auto-generate unique module names
    COUNTER = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 name=None,
                 gain_limits_db=[-6, 6],
                 shape_gain_db=0,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """
        Parameters:
        -----------

        in_channels : int
            number of input channels

        out_channels : int
            number of output channels

        kernel_size : int
            length of the adaptive kernels

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int
            frame size

        overlap_size : int
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame

        padding : List[int, int]
            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        name : str or None, optional
            module name; auto generated as limited_adaptive_conv1d_COUNT if None

        gain_limits_db : List[float, float], optional
            lower and upper limit of the per-channel output gain in dB, defaults to [-6, 6]

        shape_gain_db : float, optional
            blend (in dB, capped at 0) between the predicted kernel and an
            identity (delta) kernel; 0 dB keeps the predicted kernel unchanged

        norm_p : int, optional
            order of the norm used to normalize predicted kernels, defaults to 2

        softquant : bool, optional
            if true, wraps the kernel-predicting layer with soft_quant, defaults to False

        apply_weight_norm : bool, optional
            if true, applies weight norm to the predicting linear layers, defaults to False
        """

        super(LimitedAdaptiveConv1d, self).__init__()



        self.in_channels = in_channels
        self.out_channels = out_channels
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.gain_limits_db = gain_limits_db
        self.shape_gain_db = shape_gain_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
            LimitedAdaptiveConv1d.COUNTER += 1
        else:
            self.name = name

        # identity pass-through when weight norm is disabled
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # linear blend factor derived from shape_gain_db (capped at 1, i.e. 0 dB)
        self.shape_gain = min(1, 10**(shape_gain_db / 20))

        # per-output-channel gain, limited to [log_min, log_max] via tanh in forward;
        # 0.11512925464970229 == log(10) / 20, i.e. dB -> natural-log factor
        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        # default: roughly centered kernel
        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # raised-cosine fade-out window of length overlap_size (fixed, not trained)
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)


    def flop_count(self, rate):
        """Estimated complexity of forward() in FLOPs/s at sample rate `rate`."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        # each frame recomputes overlap_size extra samples for the cross-fade
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # gain computation

        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += 3 * overlap * frame_rate * self.out_channels

        return count

    def forward(self, x, features, debug=False):
        """ adaptive 1d convolution


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1 (fade-in) is passed to adaconv_kernel as the rising half window
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))

        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))

        # limit shape: blend predicted kernel with an identity (delta) kernel
        id_kernels = torch.zeros_like(conv_kernels)
        id_kernels[..., self.padding[1]] = 1

        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels

        # calculate gains (limited to [log_min, log_max] via tanh)
        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
        # dump per-frame diagnostics via endoscopy when requested
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)


        # fold the per-channel gains into the kernels
        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)

        # adaconv_kernel expects (batch, out_ch, in_ch, num_frames, kernel_size)
        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)

        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)


        return output
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class NoiseShaper(nn.Module):
    """Generates white noise with a frame-wise gain envelope predicted from features."""

    def __init__(self,
                 feature_dim,
                 frame_size=160
                 ):
        """
        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size

        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size

        # feature transform: two kernel-size-2 convolutions (made causal by
        # left padding in forward) mapping features to one gain per sample
        self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)

    def flop_count(self, rate):
        """Estimated complexity in FLOPs/s at sample rate `rate`."""
        frame_rate = rate / self.frame_size

        conv_layers = (self.feature_alpha1, self.feature_alpha2)
        conv_flops = sum(_conv1d_flop_count(layer, frame_rate) for layer in conv_layers)

        return conv_flops + 11 * frame_rate * self.frame_size

    def forward(self, features):
        """ creates temporally shaped noise


        Parameters:
        -----------
        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """

        batch_size, num_frames = features.size(0), features.size(1)
        total_samples = num_frames * self.frame_size

        # feature path: per-sample gains (exp keeps them positive)
        gains = F.leaky_relu(self.feature_alpha1(F.pad(features.permute(0, 2, 1), [1, 0])), 0.2)
        gains = torch.exp(self.feature_alpha2(F.pad(gains, [1, 0])))
        gains = gains.permute(0, 2, 1)

        # signal generation: white noise scaled frame-wise by the gains
        noise = gains * torch.randn((batch_size, num_frames, self.frame_size), dtype=features.dtype, device=features.device)

        return noise.reshape(batch_size, 1, total_samples)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PitchAutoCorrelator(nn.Module):
    """Normalized auto-correlation of a signal around given pitch lags.

    For every frame, correlates the frame with delayed copies of the signal
    at lags period - radius ... period + radius and normalizes by the
    geometric mean of the two energies.
    """

    def __init__(self,
                 frame_size=80,
                 pitch_min=32,
                 pitch_max=300,
                 radius=2):
        """
        Parameters:
        -----------
        frame_size : int
            samples per pitch frame
        pitch_min : int
            smallest expected pitch lag (also the left padding budget)
        pitch_max : int
            largest expected pitch lag; signal is left-padded by this amount
        radius : int
            correlations are computed for lag offsets in [-radius, radius]
        """

        super().__init__()

        self.frame_size = frame_size
        self.pitch_min = pitch_min
        self.pitch_max = pitch_max
        self.radius = radius


    def forward(self, x, periods):
        """
        Parameters:
        -----------
        x : torch.tensor
            signal of shape (batch_size, channels, num_samples)
        periods : torch.tensor (long)
            pitch lags of shape (batch_size, num_frames)

        Returns:
        --------
        torch.tensor
            normalized correlations of shape
            (batch_size, channels, num_frames, 2 * radius + 1)
        """

        num_frames = periods.size(1)
        batch_size = periods.size(0)
        num_samples = self.frame_size * num_frames
        channels = x.size(1)

        assert num_samples == x.size(-1)

        # lag offsets -radius..radius
        # (renamed from `range`, which shadowed the builtin)
        offsets = torch.arange(-self.radius, self.radius + 1, device=x.device)
        idx = torch.arange(self.frame_size * num_frames, device=x.device)
        # upsample per-frame periods to per-sample lags
        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
        # gather indices into the left-padded signal: sample position minus lag
        lookup = idx + self.pitch_max - p_up
        lookup = lookup.unsqueeze(-1) + offsets
        lookup = lookup.unsqueeze(1)

        # padding so the largest lag never indexes before the signal start
        x_pad = F.pad(x, [self.pitch_max, 0])
        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)

        # framing
        x_select = torch.gather(x_ext, 2, lookup)
        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)

        # calculate auto-correlation
        dotp = torch.sum(x_frames * lag_frames, dim=-2)
        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
        lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)

        # epsilon keeps the division finite on silent frames
        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)

        return acorr
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fractional-delay FIR interpolation filters (8 taps, 12 phases) stored as
# Q15 fixed-point integers and scaled to float by 2**15. Judging by the
# names used below (frac_01_24, frac_09_24, frac_17_24), row k appears to
# interpolate at fractional offset (2k+1)/24 of a sample — TODO confirm
# against the SILK resampler tables.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
    ],
    dtype=np.float32
) / 2**15


# Allpass coefficients for the high-quality 2x upsampler, stored as Q16
# integers; entries >= 2**15 are wrapped to negative by subtracting 65536
# (signed 16-bit representation) before scaling to float.
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
|
||||
|
||||
|
||||
def get_impz(coeffs, n):
    """Impulse response of a cascade of three first-order allpass sections.

    Parameters:
    -----------
    coeffs : sequence of 3 floats
        section coefficients; the third section uses (1 + coeffs[2])
    n : int
        number of response samples to compute

    Returns:
    --------
    np.ndarray
        impulse response of length n
    """
    state = 3 * [0]
    response = np.zeros(n)
    sample = 1  # unit impulse at time 0

    for i in range(n):
        # first allpass section
        diff = sample - state[0]
        scaled = diff * coeffs[0]
        out1 = state[0] + scaled
        state[0] = sample + scaled

        # second allpass section
        diff = out1 - state[1]
        scaled = diff * coeffs[1]
        out2 = state[1] + scaled
        state[1] = out1 + scaled

        # third allpass section (coefficient offset by 1)
        diff = out2 - state[2]
        scaled = diff * (1 + coeffs[2])
        out3 = state[2] + scaled
        state[2] = out2 + scaled

        response[i] = out3
        sample = 0  # impulse only at the first step

    return response
|
||||
|
||||
|
||||
|
||||
class SilkUpsampler(nn.Module):
    """SILK upsampler from 16 kHz to 24 or 48 kHz, implemented with convolutions.

    The signal is first 2x-upsampled with an FIR approximation of SILK's
    high-quality allpass upsampler, then interpolated by 3/2 with fractional
    FIR filters, giving 48 kHz; a stride of 2 on the output yields 24 kHz.
    """
    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}
    def __init__(self,
                 fs_in=16000,
                 fs_out=48000):
        """
        Parameters:
        -----------
        fs_in : int
            input sample rate; only 16000 is supported
        fs_out : int
            output sample rate; 24000 or 48000
        """

        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')


        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')


        # hq 2x upsampler as FIR approximation: truncate the IIR allpass
        # impulse response to 128 taps; reversed because conv1d correlates
        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        hq_2x_up_odd  = get_impz(hq_2x_up_c_odd , 128)[::-1].copy()

        # fixed (non-trainable) filter weights shaped (out=1, in=1, taps)
        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(hq_2x_up_odd ).float().view(1, 1, -1), requires_grad=False)
        # left padding of taps-1 keeps the convolution causal
        self.hq_2x_up_padding = [127, 0]

        # interpolation filters: three fractional phases of the 8-tap table
        frac_01_24 = frac_fir[0]
        frac_17_24 = frac_fir[8]
        frac_09_24 = frac_fir[4]

        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)

        # 48 kHz keeps every 3x sample; 24 kHz takes every second one
        self.stride = 1 if fs_out == 48000 else 2

    def hq_2x_up(self, x):
        """2x upsample (batch, channels, n) -> (batch, channels, 2n)
        by interleaving the even- and odd-phase filter outputs."""

        num_channels = x.size(1)

        # replicate the single-channel filters per channel for grouped conv
        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
        weight_odd  = torch.repeat_interleave(self.hq_2x_up_odd , num_channels, 0)

        x_pad   = F.pad(x, self.hq_2x_up_padding)
        y_even  = F.conv1d(x_pad, weight_even, groups=num_channels)
        y_odd   = F.conv1d(x_pad, weight_odd , groups=num_channels)

        # interleave even/odd phases sample by sample
        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)

        return y

    def interpolate_3_2(self, x):
        """3/2 interpolation: three fractional phases at stride 2, interleaved."""

        num_channels = x.size(1)

        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)

        # 8-tap filters with left padding of taps
        x_pad       = F.pad(x, [8, 0])
        y_01_24     = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
        y_17_24     = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)
        # third phase is evaluated one input sample later (rolled input)
        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)


        y = torch.cat(
            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
            dim=-1).flatten(2)

        # drop the tail samples affected by the roll wrap-around
        return y[..., :-3]

    def forward(self, x):
        """Upsample (batch, channels, n) at 16 kHz to fs_out."""

        y_2x = self.hq_2x_up(x)
        y_3x = self.interpolate_3_2(y_2x)

        return y_3x[:, :, ::self.stride]
|
||||
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class TDShaper(nn.Module):
    # COUNTER appears to be a class-level instance counter placeholder;
    # it is not used in this block — TODO confirm external use
    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
                 ):
        """Temporal envelope shaper driven by features and signal envelope.

        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size

        avg_pool_k : int, optional
            kernel size and stride for avg pooling of the envelope

        innovate : bool, optional
            if True, adds a learned innovation (noise-like) component

        pool_after : bool, optional
            if True, take log after pooling instead of before

        softquant : bool, optional
            if True, applies soft quantization to the feature transform

        apply_weight_norm : bool, optional
            if True, wraps the convolutions in weight normalization
        """

        super().__init__()


        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after

        # envelope has frame_size / avg_pool_k pooled values + 1 mean term
        assert frame_size % avg_pool_k == 0
        self.env_dim = frame_size // avg_pool_k + 1

        # identity wrapper when weight norm is disabled
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform: causal convs (kernel 2, padded in forward)
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))


    def flop_count(self, rate):
        """Approximate FLOPs per second at sample rate `rate`."""

        frame_rate = rate / self.frame_size

        # 11 (resp. 22) flops/sample cover the pointwise activation paths
        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """Log-domain, mean-removed temporal envelope of x.

        Returns a tensor of shape (batch, num_frames, env_dim) where the
        last channel is the per-frame mean log level.
        """

        x = torch.abs(x)
        # .5**16 is a small floor that keeps log() finite on silence
        if self.pool_after:
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # regroup pooled values frame-wise
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        # mean-removed envelope plus the mean itself as an extra channel
        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        Returns:
        --------
        torch.tensor
            shaped signal of shape (batch_size, 1, num_samples)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path: left padding keeps the convolutions causal
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        # exp yields strictly positive per-sample gains
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            # NOTE(review): the innovation branch consumes only f here even
            # though feature_alpha1b/c were sized for feature_dim + env_dim
            # inputs — confirm the intended input is torch.cat((f, t))
            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path: apply per-sample gains frame-wise
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)
|
||||
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def load_lpcnet_features(feature_file, version=2):
    """Load raw float32 LPCNet features from `feature_file`.

    Parameters:
    -----------
    feature_file : str
        path to a raw file of float32 frames
    version : int
        feature layout version (1 or 2); determines frame length and the
        column ranges of the individual feature groups

    Returns:
    --------
    dict with keys 'features' (cepstrum + pitch correlation), 'periods'
    (integer pitch lags) and 'lpcs' (LPC coefficients), all torch tensors
    with one row per frame.

    Raises:
    -------
    ValueError if `version` is not 1 or 2.
    """
    if version == 2:
        # column ranges [start, stop) within a frame
        layout = {
            'cepstrum': [0,18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36]
            }
        frame_length = 36

    elif version == 1:
        layout = {
            'cepstrum': [0,18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
            }
        frame_length = 55
    else:
        raise ValueError(f'unknown feature version: {version}')


    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw_features = raw_features.reshape((-1, frame_length))

    # model input: cepstrum and pitch correlation, concatenated per frame
    features = torch.cat(
        [
            raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
            raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
        ],
        dim=1
    )

    lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
    # decode stored period value to an integer lag; presumably inverts the
    # (p - 100) / 50 encoding used at feature extraction — TODO confirm
    periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
|
||||
|
||||
|
||||
|
||||
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Build a new int16 data file by pre-emphasizing `signal_path` and
    writing it, duplicated into two interleaved channels, into a buffer
    shaped like `reference_data_path`.

    Parameters:
    -----------
    signal_path : str
        raw int16 mono signal (length must be a multiple of 160)
    reference_data_path : str
        raw int16 file whose size determines the output size
    new_data_path : str
        output path (raw int16, written via memmap)
    offset : int
        number of leading samples of the pre-emphasized signal to skip
    preemph_factor : float
        pre-emphasis coefficient (y[n] = x[n] - factor * x[n-1])

    Side effects: writes `<signal_path stem>_preemph.raw` and `new_data_path`.
    """
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)


    assert len(signal) % 160 == 0
    # NOTE(review): num_frames is unused below (loop recomputes the count)
    num_frames = len(signal) // 160
    # one-sample filter memory carried across 160-sample frames
    mem = np.zeros(1)
    for fr in range(len(signal)//160):
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)

    new_data[:] = 0
    N = len(signal) - offset
    # duplicate the signal into both interleaved (stereo-like) slots,
    # starting at sample 1 — presumably to match the reference layout
    new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
|
||||
|
||||
|
||||
def parse_warpq_scores(output_file):
|
||||
""" extracts warpq scores from output file """
|
||||
|
||||
with open(output_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def parse_stats_file(file):
    """Read a stats file whose first three lines end in ': <float>' and
    return (mean, bt_mean, top_mean)."""

    with open(file, "r") as f:
        first_three = [f.readline() for _ in range(3)]

    mean, bt_mean, top_mean = (float(line.split(":")[-1]) for line in first_three)

    return mean, bt_mean, top_mean
|
||||
|
||||
def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """

    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()

    for file in os.listdir(test_folder):
        # stats files are named stats_<metric>.txt
        if not file.startswith('stats_'):
            continue

        metric = file[len("stats_") : -len(".txt")]

        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))

        results[metric] = [mean, bt_mean, top_mean]

    return results
|
||||
95
managed_components/78__esp-opus/dnn/torch/osce/utils/misc.py
Normal file
95
managed_components/78__esp-opus/dnn/torch/osce/utils/misc.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def count_parameters(model, verbose=False):
    """Return the total number of parameters of `model`.

    Parameters:
    -----------
    model : torch.nn.Module
        model whose named parameters are counted
    verbose : bool
        if True, print the per-parameter counts
    """
    total = 0
    for name, p in model.named_parameters():
        # p.numel() counts elements directly instead of allocating a
        # ones tensor just to sum it (original used torch.ones_like(p))
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total
|
||||
|
||||
def count_nonzero_parameters(model, verbose=False):
    """Return the total number of non-zero parameter entries of `model`,
    optionally printing the per-parameter counts."""

    total = 0
    for name, param in model.named_parameters():
        nonzero = torch.count_nonzero(param).item()

        if verbose:
            print(f"{name}: {nonzero} non-zero parameters")

        total += nonzero

    return total
|
||||
def retain_grads(module):
    """Ask every trainable parameter of `module` to retain its .grad
    after backward passes."""
    for param in module.parameters():
        if not param.requires_grad:
            continue
        param.retain_grad()
|
||||
|
||||
def get_grad_norm(module, p=2):
    """Return the p-norm over all gradients of `module`'s trainable parameters.

    Parameters without a gradient (e.g. before the first backward pass)
    are skipped instead of raising, so the function is safe to call at
    any point; with no gradients present it returns 0.
    """
    norm = 0
    for param in module.parameters():
        # original crashed with TypeError when param.grad was None
        if param.requires_grad and param.grad is not None:
            norm = norm + (torch.abs(param.grad) ** p).sum()

    return norm ** (1/p)
|
||||
|
||||
def create_weights(s_real, s_gen, alpha):
    """Per-scale weights exp(alpha * (real - gen)) computed from the last
    element of each score list, without tracking gradients."""
    with torch.no_grad():
        return [torch.exp(alpha * (real[-1] - gen[-1]))
                for real, gen in zip(s_real, s_gen)]
|
||||
|
||||
|
||||
def _get_candidates(module: torch.nn.Module):
|
||||
candidates = []
|
||||
for key in module.__dict__.keys():
|
||||
if hasattr(module, key + '_v'):
|
||||
candidates.append(key)
|
||||
return candidates
|
||||
|
||||
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
    """Strip weight normalization from every submodule of `model`.

    Candidate weight names are detected via their '_v' companions; a
    candidate that turns out not to be weight-normalized is skipped.
    """
    for name, m in model.named_modules():
        candidates = _get_candidates(m)

        for candidate in candidates:
            try:
                remove_weight_norm(m, name=candidate)
                if verbose: print(f'removed weight norm on weight {name}.{candidate}')
            # remove_weight_norm raises ValueError when no weight-norm hook
            # exists for this name; original used a bare except that could
            # hide unrelated errors (including KeyboardInterrupt)
            except ValueError:
                pass
|
||||
153
managed_components/78__esp-opus/dnn/torch/osce/utils/moc.py
Normal file
153
managed_components/78__esp-opus/dnn/torch/osce/utils/moc.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
    """Crude energy-based voice activity detection.

    Parameters:
    -----------
    x : np.ndarray
        input signal
    fs : int
        sample rate; frames are ~20 ms ((fs + 49) // 50 samples)
    stop_db : float
        frames whose smoothed energy is below max energy + stop_db (dB)
        are marked inactive

    Returns:
    --------
    (x trimmed to a whole number of frames, per-sample soft activity mask)
    """
    frame_length = (fs + 49) // 50
    x = x[: frame_length * (len(x) // frame_length)]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    # 5-frame moving average smooths the energy track
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    max_threshold = frame_energy.max() * 10 ** (stop_db/20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # smooth the binary decision with a normalized sine window
    # (renamed from `filter`, which shadowed the builtin)
    smoothing = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    smoothing = smoothing / smoothing.sum()

    mask = np.convolve(vactive, smoothing, mode='same')

    return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample mask to one mean value per overlapping
    analysis frame (frame_size samples, hop_size stride)."""
    num_samples = frame_size + (num_frames - 1) * hop_size

    # zero-pad or truncate so exactly num_samples entries are available
    if len(mask) < num_samples:
        mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    frame_means = [np.mean(mask[i * hop_size : i * hop_size + frame_size])
                   for i in range(num_frames)]

    return np.array(frame_means)
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Power spectra (non-negative frequencies only) of overlapping
    windowed frames of x; returns shape (num_frames, window_size//2 + 1)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    half = window_size // 2

    frames = np.stack([x[i * hop_size : i * hop_size + window_size]
                       for i in range(num_spectra)]) * win

    return np.abs(np.fft.fft(frames, axis=1)[:, : half + 1]) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """Spectral-masking matrix: geometric spread of band energies upward
    (up_factor per band) then downward (down_factor per band)."""
    up = np.zeros((num_bands, num_bands))
    down = np.zeros((num_bands, num_bands))

    for band in range(num_bands):
        # contribution of lower bands decays geometrically with distance
        up[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        # contribution toward higher bands decays with down_factor
        down[band, band:] = down_factor ** np.arange(num_bands - band)

    return down @ up
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank from band edge indices.

    Row i is 1 on bins [band_limits[i], band_limits[i+1]) and 0 elsewhere.
    """
    if num_bins is None:
        num_bins = band_limits[-1]
    num_bands = len(band_limits) - 1

    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
|
||||
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames judged silent in x are excluded
            from the final error average

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim samples to same size; rescale to 16-bit integer range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    # constant offset acts as a spectral noise floor
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking: spread reference band energies across bands
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: first-order leaky integration over frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]

    # apply mask: the same reference-derived mask raises both floors
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)

    if apply_vad:
        # restrict the average to frames with voice activity in x
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic mean emphasizes the worst frames; 1/6 root compensates the
    # squared log domain
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
|
||||
|
||||
# command-line entry point: compute the MOC score for a degraded wav file
# against a 16 kHz reference
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()


    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both files must be sampled at 16 kHz
    if max(fs1, fs2) != 16000:
        raise ValueError('error: encountered sampling frequency diffrent from 16kHz')

    # convert 16-bit PCM to floats in [-1, 1] as expected by compare()
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
||||
122
managed_components/78__esp-opus/dnn/torch/osce/utils/pitch.py
Normal file
122
managed_components/78__esp-opus/dnn/torch/osce/utils/pitch.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
def hangover(lags, num_frames=10):
    """Fill short unvoiced gaps (zero lags) with the last voiced pitch lag.

    At most `num_frames` consecutive zeros are replaced per gap; anything
    beyond that stays zero. Returns a copy; the input is not modified.
    """
    filled = lags.copy()
    zeros_filled = 0
    previous_lag = 0

    for pos, lag in enumerate(filled):
        if lag != 0:
            # voiced frame: remember the lag and reset the gap counter
            zeros_filled = 0
            previous_lag = lag
        elif zeros_filled < num_frames:
            filled[pos] = previous_lag
            zeros_filled += 1

    return filled
|
||||
|
||||
|
||||
def smooth_pitch_lags(lags, d=2):
|
||||
|
||||
assert d < 4
|
||||
|
||||
num_silk_frames = len(lags) // 4
|
||||
|
||||
smoothed_lags = lags.copy()
|
||||
|
||||
tmp = np.arange(1, d+1)
|
||||
kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
|
||||
kernel = kernel / np.sum(kernel)
|
||||
|
||||
last = lags[0:d][::-1]
|
||||
for i in range(num_silk_frames):
|
||||
frame = lags[i * 4: (i+1) * 4]
|
||||
|
||||
if np.max(np.abs(frame)) == 0:
|
||||
last = frame[4-d:]
|
||||
continue
|
||||
|
||||
if i == num_silk_frames - 1:
|
||||
next = frame[4-d:][::-1]
|
||||
else:
|
||||
next = lags[(i+1) * 4 : (i+1) * 4 + d]
|
||||
|
||||
if np.max(np.abs(next)) == 0:
|
||||
next = frame[4-d:][::-1]
|
||||
|
||||
if np.max(np.abs(last)) == 0:
|
||||
last = frame[0:d][::-1]
|
||||
|
||||
smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')
|
||||
|
||||
smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)
|
||||
|
||||
last = frame[4-d:]
|
||||
|
||||
return smoothed_lags
|
||||
|
||||
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
    """Normalized auto-correlations of each frame of x around its pitch lag.

    For every frame, correlates the frame with delayed copies of the signal
    at lags lag - radius ... lag + radius (and additionally at the doubled
    lag when add_double_lag_acorr is set and the lag is at least
    no_pitch_threshold).

    Parameters:
    -----------
    x : np.ndarray
        signal; length must be a multiple of frame_size
    frame_size : int
        samples per frame
    lags : np.ndarray
        one pitch lag per frame
    history : np.ndarray, optional
        samples preceding x; zeros of the required length when omitted
    max_lag : int
        largest supported lag (determines required history length)
    radius : int
        lag search radius; 2*radius + 1 correlations per lag
    add_double_lag_acorr : bool
        also compute correlations at twice the lag
    no_pitch_threshold : int
        lags below this are treated as "no pitch" and not doubled

    Returns:
    --------
    (acorrs, lags) : correlation matrix of shape
        (num_frames, (2 if add_double_lag_acorr else 1) * (2*radius + 1))
        and the (copied) lags
    """
    # epsilon keeps the normalization finite on silent frames
    eps = 1e-9

    lag_multiplier = 2 if add_double_lag_acorr else 1

    if history is None:
        history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)

    # index of the first sample of x inside the extended signal
    offset = len(history)

    assert offset >= max_lag + radius
    assert len(x) % frame_size == 0

    num_frames = len(x) // frame_size
    lags = lags.copy()

    x_ext = np.concatenate((history, x), dtype=x.dtype)

    d = radius
    num_acorrs = 2 * d + 1
    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)

    for idx in range(num_frames):
        lag = lags[idx].item()
        frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]

        # k = 0: base lag; k = 1 (optional): doubled lag for voiced frames
        for k in range(lag_multiplier):
            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
            for j in range(num_acorrs):
                # delayed frame at lag lag1 - d + j
                past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)

    return acorrs, lags
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import scipy
|
||||
import scipy.signal
|
||||
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
|
||||
|
||||
def spec_from_lpc(a, n_fft=128, eps=1e-9):
    """Power spectrum of the LPC synthesis filter 1 / A(z).

    Args:
        a: LPC coefficients, shape (..., order).
        n_fft: FFT size; must exceed order + 1.
        eps: regularizer added to |A|^2 before inversion.

    Returns:
        Array of shape (..., n_fft // 2 + 1) with the one-sided power spectrum.
    """
    lpc_order = a.shape[-1]
    assert lpc_order + 1 < n_fft

    # build the analysis polynomial A(z) = 1 - sum_k a_k z^-k, zero-padded to n_fft
    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1:lpc_order + 1] = -a

    # one-sided squared magnitude of A, then invert to get the synthesis spectrum
    denom = np.abs(np.fft.fft(poly, axis=-1)[..., :n_fft // 2 + 1]) ** 2
    return 1 / (denom + eps)
|
||||
|
||||
def silk_feature_factory(no_pitch_value=256,
                         acorr_radius=2,
                         pitch_hangover=8,
                         num_bands_clean_spec=64,
                         num_bands_noisy_spec=18,
                         noisy_spec_scale='opus',
                         noisy_apply_dct=True,
                         add_double_lag_acorr=False
                         ):
    """Build a feature-extraction closure for SILK enhancement models.

    The returned callable turns decoded SILK parameters plus the noisy signal
    into the per-frame feature matrix consumed by the model.

    Args:
        no_pitch_value: period value substituted for unvoiced frames (period 0).
        acorr_radius: half-width of the autocorrelation window around each lag.
        pitch_hangover: frames of pitch hangover smoothing (0 disables it).
        num_bands_clean_spec / num_bands_noisy_spec: filter-bank sizes.
        noisy_spec_scale: frequency scale for the noisy-signal filter bank.
        noisy_apply_dct: use cepstrum (DCT of log spectrum) for the noisy signal.
        add_double_lag_acorr: also correlate around twice the pitch lag.

    Returns:
        create_features(noisy, noisy_history, lpcs, gains, ltps, periods)
        -> (features float32 array, periods int64 array).
    """
    # analysis window and filter banks are built once and captured by the closure
    w = scipy.signal.windows.cosine(320)
    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)

    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods):
        """Assemble one feature matrix; all inputs are per-5ms-frame arrays
        except noisy/noisy_history which are raw samples."""
        periods = periods.copy()  # hangover/no-pitch substitution must not touch caller data

        if pitch_hangover > 0:
            periods = hangover(periods, num_frames=pitch_hangover)

        # unvoiced frames carry period 0; remap to the dedicated embedding index
        periods[periods == 0] = no_pitch_value

        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)

        # spectra come at a 10 ms rate; repeat(2, 0) upsamples to the 5 ms frame rate
        if noisy_apply_dct:
            noisy_cepstrum = np.repeat(
                cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
        else:
            noisy_cepstrum = np.repeat(
                log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)

        log_gains = np.log(gains + 1e-9).reshape(-1, 1)

        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)

        features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)

        return features, periods.astype(np.int64)

    return create_features
|
||||
|
||||
|
||||
|
||||
def load_inference_data(path,
                        no_pitch_value=256,
                        skip=92,
                        preemph=0.85,
                        acorr_radius=2,
                        pitch_hangover=8,
                        num_bands_clean_spec=64,
                        num_bands_noisy_spec=18,
                        noisy_spec_scale='opus',
                        noisy_apply_dct=True,
                        add_double_lag_acorr=False,
                        **kwargs):
    """Load a dumped SILK feature set from *path* and build model inputs.

    Reads the raw feature files written by the feature dumper, trims them to a
    consistent number of 5 ms frames, computes the model feature matrix and
    pre-emphasizes the signal.

    Args:
        path: directory containing features_*.f32/.s16/.s32 and noisy.s16.
        skip: decoder delay (in samples) to re-insert before the signal.
        preemph: pre-emphasis coefficient applied to the signal (0 disables).
        Remaining keyword arguments mirror silk_feature_factory; unknown
        kwargs are ignored with a notice.

    Returns:
        (signal, features, periods, numbits) as torch tensors.
    """
    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")

    lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
    ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
    gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)

    # load signal, add back delay and pre-emphasize
    signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)

    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_double_lag_acorr)

    # trim everything to a whole number of 5 ms frames (80 samples each)
    num_frames = min((len(signal) // 320) * 4, len(lpcs))
    signal = signal[: num_frames * 80]
    lpcs = lpcs[: num_frames]
    ltps = ltps[: num_frames]
    gains = gains[: num_frames]
    periods = periods[: num_frames]
    num_bits = num_bits[: num_frames // 4]
    # fixed: previously truncated num_bits into num_bits_smooth, discarding
    # the smoothed values that were just loaded from file
    num_bits_smooth = num_bits_smooth[: num_frames // 4]

    # numbits features are per 20 ms packet; repeat to the 5 ms frame rate
    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)

    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods)

    if preemph > 0:
        signal[1:] -= preemph * signal[:-1]

    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
|
||||
@@ -0,0 +1,110 @@
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Per-row quantization scale for an (n_out, n_in) weight matrix.

    The scale is the smallest value such that individual quantized weights fit
    in [-127, 127] and sums of adjacent weight pairs fit in [-129, 129]
    (the pairwise bound matters for SIMD dot-product kernels that accumulate
    pairs — presumably matching the opus int8 kernels; confirm there).

    Args:
        weight: 2-d tensor; the second dimension must be a multiple of 4.

    Returns:
        1-d tensor of length n_out with one scale per output row.
    """
    # decorator already disables grad; no inner no_grad needed
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # pad rows up to the next multiple of 8 (fixed: the original computed
        # n_out - n_out % 8, which is not the missing row count; the padded
        # zero rows are sliced off below either way)
        pad = 8 - n_out % 8
        weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)

    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    scale = torch.maximum(scale_max, scale_sum)

    return scale[:n_out]
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Uniform noise in [-scale/2, scale/2) per row, matched to *weight*'s layout.

    The weight is flattened into the 2-d (rows, columns) layout that the
    quantizer uses, scaled noise is drawn there, and the result is folded back
    into the original tensor shape.

    Raises:
        ValueError: for module/weight combinations with no known layout.
    """
    if isinstance(module, torch.nn.Conv1d):
        # (out, in, k) -> rows of length in*k in (out, k, in) order
        flat = weight.permute(0, 2, 1).flatten(1)
        eps = torch.rand_like(flat) - 0.5
        eps = eps * compute_optimal_scale(flat).unsqueeze(-1)
        return eps.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)

    if isinstance(module, torch.nn.ConvTranspose1d):
        # transposed conv stores (in, out, k); quantizer rows are (k*out, in)
        ch_in, ch_out, ksize = weight.shape
        flat = weight.permute(2, 1, 0).reshape(ksize * ch_out, ch_in)
        eps = torch.rand_like(flat) - 0.5
        eps = eps * compute_optimal_scale(flat).unsqueeze(-1)
        return eps.reshape(ksize, ch_out, ch_in).permute(2, 1, 0)

    if weight.dim() == 2:
        # plain (out, in) matrix — already in quantizer layout
        eps = torch.rand_like(weight) - 0.5
        return eps * compute_optimal_scale(weight).unsqueeze(-1)

    raise ValueError('unknown quantization setting')
|
||||
|
||||
class SoftQuant:
    """Forward hooks that simulate quantization noise during training.

    A pre-forward hook adds noise to the named parameters, and a post-forward
    hook subtracts the exact same noise again, so the stored weights are left
    unchanged while the forward pass sees quantization-like perturbations.
    Inactive when the module is in eval mode.

    Attributes:
        names: parameter names to perturb.
        scale: relative noise magnitude; when None a per-row optimal scale is
            derived via q_scaled_noise.
    """
    # NOTE(review): this annotation looks vestigial — instances store `names`.
    name: str

    def __init__(self, names: str, scale: float) -> None:
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Hook body: add noise when before=True, remove it when before=False."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=None, scale=None):
        """Register the noise hooks on *module* and return the SoftQuant instance.

        Args:
            module: any nn.Module owning the named parameters.
            names: parameter names to perturb; defaults to ['weight']
                (None sentinel avoids a shared mutable default).
            scale: see class docstring.

        Raises:
            ValueError: if the module lacks one of the named attributes.
        """
        names = ['weight'] if names is None else names
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                # fixed: previously raised ValueError("") with no message
                raise ValueError(f"module {type(module).__name__} has no attribute '{name}'")

        fn_before = lambda *x: fn(*x, before=True)
        fn_after = lambda *x: fn(*x, before=False)
        # tag the hooks so remove_soft_quant can identify them later
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        # (removed a stray no-op `module` expression statement here)
        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=None, scale=None):
    """Attach soft-quantization noise hooks to *module* and return it.

    Thin convenience wrapper around SoftQuant.apply. `names` defaults to
    ['weight'] via a None sentinel instead of a shared mutable default.
    """
    SoftQuant.apply(module, ['weight'] if names is None else names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=['weight']):
    """Remove soft-quantization hooks previously registered for *names*.

    Modules without matching hooks are returned unchanged.
    """
    # fixed: collect matching keys first — deleting entries from the hook
    # dicts while iterating them raises RuntimeError
    for hooks in (module._forward_pre_hooks, module._forward_hooks):
        stale = [k for k, hook in hooks.items()
                 if hasattr(hook, 'sqm')
                 and isinstance(hook.sqm, SoftQuant)
                 and hook.sqm.names == names]
        for k in stale:
            del hooks[k]

    return module
|
||||
210
managed_components/78__esp-opus/dnn/torch/osce/utils/spec.py
Normal file
210
managed_components/78__esp-opus/dnn/torch/osce/utils/spec.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.fftpack
|
||||
import torch
|
||||
|
||||
def erb(f):
    """Equivalent rectangular bandwidth at frequency f.

    # NOTE(review): the 24.7 * (4.37 f + 1) form expects f in kHz — confirm
    # against callers, which appear to pass Hz.
    """
    scaled = 4.37 * f + 1
    return 24.7 * scaled
|
||||
|
||||
def inv_erb(e):
    """Inverse of erb(): frequency whose ERB equals e."""
    ratio = e / 24.7 - 1
    return ratio / 4.37
|
||||
|
||||
def bark(f):
    """Bark-scale value for frequency f in Hz (asinh approximation)."""
    warped = m.asinh(f / 600)
    return 6 * warped
|
||||
|
||||
def inv_bark(b):
    """Inverse of bark(): frequency in Hz for Bark value b."""
    warped = m.sinh(b / 6)
    return 600 * warped
|
||||
|
||||
|
||||
# Maps scale name -> [forward transform (frequency -> scale units), inverse].
# Consumed by create_filter_bank to place band centers uniformly on the scale.
scale_dict = {
    'bark': [bark, inv_bark],
    'erb': [erb, inv_erb]
}
|
||||
|
||||
def gen_filterbank(N, Fs=16000, keep_size=False):
|
||||
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
|
||||
M = N + 1 if keep_size else N
|
||||
out_freq = (np.arange(M, dtype='float32')/N*Fs/2)[:,None]
|
||||
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
|
||||
ERB_N = 24.7 + .108*in_freq
|
||||
delta = np.abs(in_freq-out_freq)/ERB_N
|
||||
center = (delta<.5).astype('float32')
|
||||
R = -12*center*delta**2 + (1-center)*(3-12*delta)
|
||||
RE = 10.**(R/10.)
|
||||
norm = np.sum(RE, axis=1)
|
||||
RE = RE/norm[:, np.newaxis]
|
||||
return torch.from_numpy(RE)
|
||||
|
||||
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
    """Triangular filter bank over the one-sided spectrum of an n_fft FFT.

    Band centers are spaced uniformly on the requested perceptual scale
    ('bark'/'erb' via scale_dict, or the fixed Opus 18-band layout), and each
    FFT bin is split linearly between its two neighboring bands.

    Args:
        num_bands: number of bands (forced to 18 for scale='opus').
        n_fft: FFT size; the bank covers n_fft // 2 + 1 bins.
        fs: sample rate in Hz.
        round_center_bins: snap band centers to integer bins.
        return_upper: mirror the bank to cover the full (two-sided) FFT.
        normalize: scale each band to unit sum.

    Returns:
        Array of shape (num_bands, n_fft//2 + 1), or wider with return_upper.
    """
    f0 = 0
    num_bins = n_fft // 2 + 1
    f1 = fs / n_fft * (num_bins - 1)  # Nyquist frequency
    fstep = fs / n_fft

    if scale == 'opus':
        # Opus band edges, specified in bins of a 5 ms frame and rescaled to n_fft
        bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
        fac = 1000 * n_fft / fs / 5
        if num_bands != 18:
            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
            num_bands = 18
        center_bins = np.array([fac * bin for bin in bins_5ms])
    else:
        to_scale, from_scale = scale_dict[scale]

        s0 = to_scale(f0)
        s1 = to_scale(f1)

        # endpoints pinned to DC and Nyquist; interior centers uniform on the scale
        center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
        center_bins = (center_freqs - f0) / fstep

    if round_center_bins:
        center_bins = np.round(center_bins)

    filter_bank = np.zeros((num_bands, num_bins))

    band = 0
    for bin in range(num_bins):
        # update band index
        if bin > center_bins[band + 1]:
            band += 1

        # calculate filter coefficients: linear split between adjacent bands
        frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
        filter_bank[band][bin] = frac
        filter_bank[band + 1][bin] = 1 - frac

    if return_upper:
        # mirror interior bins to cover the negative-frequency half
        extend = n_fft - num_bins
        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)

    if normalize:
        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)

    return filter_bank
|
||||
|
||||
|
||||
def compressed_log_spec(pspec):
    """Log10 of a 1-d power spectrum with adaptive per-band lower bounds.

    Each band is floored both by the running maximum seen so far and by the
    previous (floored) band minus 2.5, which limits downward jumps across
    bands.
    """
    out = np.zeros_like(pspec)
    n_bands = pspec.shape[-1]

    running_max = -2
    follower = -2

    for i in range(n_bands):
        band = np.log10(pspec[i] + 1e-9)
        band = max(running_max, max(follower - 2.5, band))
        out[i] = band
        running_max = max(running_max, band)
        follower = max(follower - 2.5, band)

    return out
|
||||
|
||||
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
    """Log (or compressed-log) magnitude response of the SILK LPC filter 1/A(z).

    Args:
        a: LPC coefficients, shape (..., order).
        fb: optional filter bank applied to the spectrum (bands x bins).
        n_fft: FFT size; must exceed order + 1.
        eps: regularizer for the inversion and the log.
        gamma: bandwidth-expansion factor applied as a_k <- gamma^k * a_k.
        compress: apply compressed_log_spec instead of a plain log.
        power: exponent applied to |A| before inversion.

    Returns:
        Array of shape (..., n_fft//2 + 1) or (..., num_bands) when fb is given.
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # bandwidth expansion
    a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)

    coeffs = np.zeros((*a.shape[:-1], n_fft))
    coeffs[..., 0] = 1
    coeffs[..., 1:order + 1] = -a

    mag = np.abs(np.fft.fft(coeffs, axis=-1)[..., :n_fft // 2 + 1]) ** power
    spec = 1 / (mag + eps)

    banded = spec if fb is None else np.matmul(spec, fb.T)

    if compress:
        return np.apply_along_axis(compressed_log_spec, -1, banded)
    return np.log(banded + eps)
|
||||
|
||||
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
    """Orthonormal DCT-II cepstrum of the SILK LPC log spectrum."""
    log_spec = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
|
||||
|
||||
|
||||
|
||||
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
    """Log magnitude spectra over 50%-overlapping frames of x.

    Args:
        x: 1-d signal; 2*len(x) must be a multiple of frame_size.
        frame_size: even frame length in samples.
        fb: optional filter bank (bands x bins) applied to the magnitudes.
        window: optional analysis window of length frame_size.
        power: exponent applied to the FFT magnitudes.

    Returns:
        Array of shape (num_frames, frame_size//2 + 1) or (num_frames, bands).
    """
    assert (2 * len(x)) % frame_size == 0
    assert frame_size % 2 == 0

    half = frame_size // 2
    n_even = len(x) // frame_size
    n_odd = (len(x) - half) // frame_size
    n_bins = half + 1

    even_frames = x[:n_even * frame_size].reshape(-1, frame_size)
    odd_frames = x[half: half + frame_size * n_odd].reshape(-1, frame_size)

    # interleave even/odd hops so rows come out in time order
    frames = np.empty((even_frames.size + odd_frames.size), dtype=x.dtype).reshape((-1, frame_size))
    frames[0::2, :] = even_frames
    frames[1::2, :] = odd_frames

    if window is not None:
        frames *= window.reshape(1, -1)

    spec = np.abs(np.fft.fft(frames, n=frame_size, axis=-1))[:, :n_bins] ** power

    if fb is not None:
        spec = np.matmul(spec, fb.T)

    return np.log(spec + 1e-9)
|
||||
|
||||
|
||||
def cepstrum(x, frame_size, fb=None, window=None):
    """Orthonormal DCT-II cepstrum on 50%-overlapping frames (see log_spectrum)."""
    log_spec = log_spectrum(x, frame_size, fb, window)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
|
||||
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
# Placeholder; rebound to the populated registry at the bottom of this module.
setup_dict = dict()
||||
|
||||
# Training setup for the LACE model.
lace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'lace',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 128,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fixed: 'pitch_hangover' appeared twice (8, then 0); the duplicate
        # silently discarded the 8, so the effective value 0 is kept.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
|
||||
|
||||
|
||||
# Training setup for the NoLACE model (non-adversarial).
nolace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fixed: 'pitch_hangover' appeared twice (8, then 0); the duplicate
        # silently discarded the 8, so the effective value 0 is kept.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
|
||||
|
||||
# Adversarial (GAN) fine-tuning setup for the NoLACE model.
nolace_setup_adv = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [0, 0, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fixed: 'pitch_hangover' appeared twice (8, then 0); the duplicate
        # silently discarded the 8, so the effective value 0 is kept.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'training': {
        'adv_target': 'target_orig',
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 10,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 20,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
    }
}
|
||||
|
||||
|
||||
# Training setup for the LaVoce vocoder (non-adversarial).
lavoce_setup = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 256,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0005,
        'lr_decay_factor': 2.5e-05
    },
    'validation_dataset': '/local/datasets/lpcnet_large/validation'
}
|
||||
|
||||
# Adversarial (GAN) fine-tuning setup for the LaVoce vocoder.
lavoce_setup_adv = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09
    },
}
|
||||
|
||||
|
||||
# Registry of training setups, keyed by the setup name selected by the caller.
setup_dict = {
    'lace': lace_setup,
    'nolace': nolace_setup,
    'nolace_adv': nolace_setup_adv,
    'lavoce': lavoce_setup,
    'lavoce_adv': lavoce_setup_adv
}
|
||||
Reference in New Issue
Block a user