add some code
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from .lace import LACE
|
||||
from .no_lace import NoLACE
|
||||
from .lavoce import LaVoce
|
||||
from .lavoce_400 import LaVoce400
|
||||
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
|
||||
|
||||
model_dict = {
|
||||
'lace': LACE,
|
||||
'nolace': NoLACE,
|
||||
'lavoce': LaVoce,
|
||||
'lavoce400': LaVoce400,
|
||||
'fdmresdisc': FDMResDisc,
|
||||
}
|
||||
@@ -0,0 +1,974 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
import torchaudio
|
||||
|
||||
from utils.spec import gen_filterbank
|
||||
|
||||
# auxiliary functions
|
||||
|
||||
def remove_all_weight_norms(module):
|
||||
for m in module.modules():
|
||||
if hasattr(m, 'weight_v'):
|
||||
nn.utils.remove_weight_norm(m)
|
||||
|
||||
|
||||
def create_smoothing_kernel(h, w, gamma=1.5):
|
||||
|
||||
ch = h / 2 - 0.5
|
||||
cw = w / 2 - 0.5
|
||||
|
||||
sh = gamma * ch
|
||||
sw = gamma * cw
|
||||
|
||||
vx = ((torch.arange(h) - ch) / sh) ** 2
|
||||
vy = ((torch.arange(w) - cw) / sw) ** 2
|
||||
vals = vx.view(-1, 1) + vy.view(1, -1)
|
||||
kernel = torch.exp(- vals)
|
||||
kernel = kernel / kernel.sum()
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
def create_kernel(h, w, sh, sw):
|
||||
# proto kernel gives disjoint partition of 1
|
||||
proto_kernel = torch.ones((sh, sw))
|
||||
|
||||
# create smoothing kernel eta
|
||||
h_eta, w_eta = h - sh + 1, w - sw + 1
|
||||
assert h_eta > 0 and w_eta > 0
|
||||
eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta)
|
||||
|
||||
kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0)
|
||||
kernel = F.conv2d(kernel0, eta)
|
||||
|
||||
return kernel
|
||||
|
||||
# positional embeddings
|
||||
class FrequencyPositionalEmbedding(nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
N = x.size(2)
|
||||
args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N
|
||||
cos = torch.cos(args).reshape(1, 1, -1, 1)
|
||||
sin = torch.sin(args).reshape(1, 1, -1, 1)
|
||||
zeros = torch.zeros_like(x[:, 0:1, :, :])
|
||||
|
||||
y = torch.cat((x, zeros + sin, zeros + cos), dim=1)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class PositionalEmbedding2D(nn.Module):
|
||||
def __init__(self, d=5):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.d = d
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
N = x.size(2)
|
||||
M = x.size(3)
|
||||
|
||||
h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
|
||||
w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
|
||||
coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)
|
||||
|
||||
h_sin = torch.sin(coeffs * h_args)
|
||||
h_cos = torch.sin(coeffs * h_args)
|
||||
w_sin = torch.sin(coeffs * w_args)
|
||||
w_cos = torch.sin(coeffs * w_args)
|
||||
|
||||
zeros = torch.zeros_like(x[:, 0:1, :, :])
|
||||
|
||||
y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
# spectral discriminator base class
|
||||
class SpecDiscriminatorBase(nn.Module):
|
||||
RECEPTIVE_FIELD_MAX_WIDTH=10000
|
||||
def __init__(self,
|
||||
layers,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7000],
|
||||
noise_gain=1e-3,
|
||||
fmap_start_index=0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.resolution = resolution
|
||||
self.fs = fs
|
||||
self.noise_gain = noise_gain
|
||||
self.fmap_start_index = fmap_start_index
|
||||
|
||||
if fmap_start_index >= len(layers):
|
||||
raise ValueError(f'fmap_start_index is larger than number of layers')
|
||||
|
||||
# filter bank for noise shaping
|
||||
n_fft = resolution[0]
|
||||
|
||||
self.filterbank = nn.Parameter(
|
||||
gen_filterbank(n_fft // 2, fs, keep_size=True),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
# roi bins
|
||||
f_step = fs / n_fft
|
||||
self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
|
||||
self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# determine receptive field size, offsets and strides
|
||||
|
||||
hw = 1000
|
||||
while True:
|
||||
x = torch.zeros((1, hw, hw))
|
||||
with torch.no_grad():
|
||||
y = self.run_layer_stack(x)[-1]
|
||||
|
||||
pos0 = [y.size(-2) // 2, y.size(-1) // 2]
|
||||
pos1 = [t + 1 for t in pos0]
|
||||
|
||||
hs0, ws0 = self._receptive_field((hw, hw), pos0)
|
||||
hs1, ws1 = self._receptive_field((hw, hw), pos1)
|
||||
|
||||
h0 = hs0[1] - hs0[0] + 1
|
||||
h1 = hs1[1] - hs1[0] + 1
|
||||
w0 = ws0[1] - ws0[0] + 1
|
||||
w1 = ws1[1] - ws1[0] + 1
|
||||
|
||||
if h0 != h1 or w0 != w1:
|
||||
hw = 2 * hw
|
||||
else:
|
||||
|
||||
# strides
|
||||
sh = hs1[0] - hs0[0]
|
||||
sw = ws1[0] - ws0[0]
|
||||
|
||||
if sh == 0 or sw == 0: continue
|
||||
|
||||
# offsets
|
||||
oh = hs0[0] - sh * pos0[0]
|
||||
ow = ws0[0] - sw * pos0[1]
|
||||
|
||||
# overlap factor
|
||||
overlap = w0 / sw + h0 / sh
|
||||
|
||||
#print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")
|
||||
self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}
|
||||
|
||||
break
|
||||
|
||||
if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
|
||||
print("warning: exceeded max size while trying to determine receptive field")
|
||||
|
||||
# create transposed convolutional kernel
|
||||
#self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
|
||||
|
||||
def run_layer_stack(self, spec):
|
||||
|
||||
output = []
|
||||
|
||||
x = spec.unsqueeze(1)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
output.append(x)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, x):
|
||||
""" returns array with feature maps and final score at index -1 """
|
||||
|
||||
output = []
|
||||
|
||||
x = self.spectrogram(x)
|
||||
|
||||
output = self.run_layer_stack(x)
|
||||
|
||||
return output[self.fmap_start_index:]
|
||||
|
||||
def receptive_field(self, output_pos):
|
||||
|
||||
if self.receptive_field_params is not None:
|
||||
s, o, h = self.receptive_field_params['height']
|
||||
h_min = output_pos[0] * s + o + self.start_bin
|
||||
h_max = h_min + h
|
||||
h_min = max(h_min, self.start_bin)
|
||||
h_max = min(h_max, self.stop_bin)
|
||||
|
||||
s, o, w = self.receptive_field_params['width']
|
||||
w_min = output_pos[1] * s + o
|
||||
w_max = w_min + w
|
||||
|
||||
return (h_min, h_max), (w_min, w_max)
|
||||
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def _receptive_field(self, input_dims, output_pos):
|
||||
""" determines receptive field probabilistically via autograd (slow) """
|
||||
|
||||
x = torch.randn((1,) + input_dims, requires_grad=True)
|
||||
|
||||
# run input through layers
|
||||
y = self.run_layer_stack(x)[-1]
|
||||
b, c, h, w = y.shape
|
||||
|
||||
if output_pos[0] >= h or output_pos[1] >= w:
|
||||
raise ValueError("position out of range")
|
||||
|
||||
mask = torch.zeros((b, c, h, w))
|
||||
mask[0, 0, output_pos[0], output_pos[1]] = 1
|
||||
|
||||
(mask * y).sum().backward()
|
||||
|
||||
hs, ws = torch.nonzero(x.grad[0], as_tuple=True)
|
||||
|
||||
h_min, h_max = hs.min().item(), hs.max().item()
|
||||
w_min, w_max = ws.min().item(), ws.max().item()
|
||||
|
||||
return [h_min, h_max], [w_min, w_max]
|
||||
|
||||
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
|
||||
nn.init.orthogonal_(m.weight.data)
|
||||
|
||||
|
||||
def spectrogram(self, x):
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
x = x.squeeze(1)
|
||||
window = getattr(torch, 'hann_window')(win_length).to(x.device)
|
||||
|
||||
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
|
||||
window=window, return_complex=True) #[B, F, T]
|
||||
x = torch.abs(x)
|
||||
|
||||
# noise floor following spectral envelope
|
||||
smoothed_x = torch.matmul(self.filterbank, x)
|
||||
noise = torch.randn_like(x) * smoothed_x * self.noise_gain
|
||||
x = x + noise
|
||||
|
||||
# frequency ROI
|
||||
x = x[:, self.start_bin : self.stop_bin + 1, ...]
|
||||
|
||||
return torchaudio.functional.amplitude_to_DB(x,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)#torch.sqrt(x)
|
||||
|
||||
def grad_map(self, x):
|
||||
self.zero_grad()
|
||||
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
|
||||
window = getattr(torch, 'hann_window')(win_length).to(x.device)
|
||||
|
||||
y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
|
||||
window=window, return_complex=True) #[B, F, T]
|
||||
y = torch.abs(y)
|
||||
|
||||
specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
|
||||
|
||||
specgram.requires_grad = True
|
||||
specgram.retain_grad()
|
||||
|
||||
if specgram.grad is not None:
|
||||
specgram.grad.zero_()
|
||||
|
||||
y = specgram[:, self.start_bin : self.stop_bin + 1, ...]
|
||||
|
||||
scores = self.run_layer_stack(y)[-1]
|
||||
|
||||
loss = torch.mean((1 - scores) ** 2)
|
||||
loss.backward()
|
||||
|
||||
return specgram.data[0], torch.abs(specgram.grad)[0]
|
||||
|
||||
def relevance_map(self, x):
|
||||
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
y = x.view(-1)
|
||||
window = getattr(torch, 'hann_window')(win_length).to(x.device)
|
||||
|
||||
y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
|
||||
window=window, return_complex=True) #[B, F, T]
|
||||
y = torch.abs(y)
|
||||
|
||||
specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
|
||||
|
||||
|
||||
scores = self.forward(x)[-1]
|
||||
|
||||
sh, _, h = self.receptive_field_params['height']
|
||||
sw, _, w = self.receptive_field_params['width']
|
||||
kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
|
||||
with torch.no_grad():
|
||||
pad_w = (w + sw - 1) // sw
|
||||
pad_h = (h + sh - 1) // sh
|
||||
padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
|
||||
# CAVE: padding should be derived from offsets
|
||||
rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2))
|
||||
rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw]
|
||||
|
||||
relevance = torch.zeros_like(specgram)
|
||||
relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv
|
||||
|
||||
|
||||
return specgram, relevance
|
||||
|
||||
|
||||
def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
|
||||
""" layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """
|
||||
|
||||
# ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations
|
||||
|
||||
def newconv2d(layer,g):
|
||||
|
||||
new_layer = nn.Conv2d(layer.in_channels,
|
||||
layer.out_channels,
|
||||
layer.kernel_size,
|
||||
stride=layer.stride,
|
||||
padding=layer.padding,
|
||||
dilation=layer.dilation,
|
||||
groups=layer.groups)
|
||||
|
||||
try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
|
||||
except AttributeError: pass
|
||||
|
||||
try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
|
||||
except AttributeError: pass
|
||||
|
||||
return new_layer
|
||||
|
||||
bounds = {
|
||||
64: [-85.82449722290039, 2.1755014657974243],
|
||||
128: [-84.49211349487305, 3.5078893899917607],
|
||||
256: [-80.33127822875977, 7.6687201976776125],
|
||||
512: [-73.79328079223633, 14.20672025680542],
|
||||
1024: [-67.59239501953125, 20.40760498046875],
|
||||
2048: [-62.31902580261231, 25.680974197387698],
|
||||
}
|
||||
|
||||
nfft = self.resolution[0]
|
||||
if low is None: low = bounds[nfft][0]
|
||||
if high is None: high = bounds[nfft][1]
|
||||
|
||||
remove_all_weight_norms(self)
|
||||
|
||||
for p in self.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad.zero_()
|
||||
|
||||
num_layers = len(self.layers)
|
||||
X = self.spectrogram(x). detach()
|
||||
|
||||
|
||||
# forward pass
|
||||
A = [X.unsqueeze(1)] + [None] * len(self.layers)
|
||||
|
||||
for i in range(num_layers - 1):
|
||||
A[i + 1] = self.layers[i](A[i])
|
||||
|
||||
# initial relevance is last layer without activation
|
||||
r = A[-2]
|
||||
last_layer_rs = [r]
|
||||
layer = self.layers[-1]
|
||||
for sublayer in list(layer)[:-1]:
|
||||
r = sublayer(r)
|
||||
last_layer_rs.append(r)
|
||||
|
||||
|
||||
mask = torch.zeros_like(r)
|
||||
mask.requires_grad_(False)
|
||||
if verbose:
|
||||
print(r.min(), r.max())
|
||||
if label in {'both', 'fake'}:
|
||||
mask[r < -threshold] = 1
|
||||
if label in {'both', 'real'}:
|
||||
mask[r > threshold] = 1
|
||||
r = r * mask
|
||||
|
||||
# backward pass
|
||||
R = [None] * num_layers + [r]
|
||||
|
||||
for l in range(1, num_layers)[::-1]:
|
||||
A[l] = (A[l]).data.requires_grad_(True)
|
||||
|
||||
layer = nn.Sequential(*(list(self.layers[l])[:-1]))
|
||||
z = layer(A[l]) + eps
|
||||
s = (R[l+1] / z).data
|
||||
(z*s).sum().backward()
|
||||
c = A[l].grad
|
||||
R[l] = (A[l] * c).data
|
||||
|
||||
# first layer
|
||||
A[0] = (A[0].data).requires_grad_(True)
|
||||
|
||||
Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
|
||||
Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)
|
||||
|
||||
if len(list(self.layers)) > 2:
|
||||
# unsafe way to check for embedding layer
|
||||
embed = list(self.layers[0])[0]
|
||||
conv = list(self.layers[0])[1]
|
||||
|
||||
layer = nn.Sequential(embed, conv)
|
||||
layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
|
||||
layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))
|
||||
|
||||
else:
|
||||
layer = list(self.layers[0])[0]
|
||||
layerl = newconv2d(layer, lambda p: p.clamp(min=0))
|
||||
layerh = newconv2d(layer, lambda p: p.clamp(max=0))
|
||||
|
||||
|
||||
z = layer(A[0])
|
||||
z -= layerl(Xl) + layerh(Xh)
|
||||
s = (R[1] / z).data
|
||||
(z * s).sum().backward()
|
||||
c, cp, cm = A[0].grad, Xl.grad, Xh.grad
|
||||
|
||||
R[0] = (A[0] * c + Xl * cp + Xh * cm)
|
||||
#R[0] = (A[0] * c).data
|
||||
|
||||
return X, R[0].mean(dim=1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def create_3x3_conv_plan(num_layers : int,
|
||||
f_stretch : int,
|
||||
f_down : int,
|
||||
t_stretch : int,
|
||||
t_down : int
|
||||
):
|
||||
|
||||
|
||||
""" creates a stride, dilation, padding plan for a 2d conv network
|
||||
|
||||
Args:
|
||||
num_layers (int): number of layers
|
||||
f_stretch (int): log_2 of stretching factor along frequency axis
|
||||
f_down (int): log_2 of downsampling factor along frequency axis
|
||||
t_stretch (int): log_2 of stretching factor along time axis
|
||||
t_down (int): log_2 of downsampling factor along time axis
|
||||
|
||||
Returns:
|
||||
list(list(tuple)): list containing entries [(stride_t, stride_f), (dilation_t, dilation_f), (padding_t, padding_f)]
|
||||
"""
|
||||
|
||||
assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
|
||||
assert f_stretch < num_layers and t_stretch < num_layers
|
||||
|
||||
def process_dimension(n_layers, stretch, down):
|
||||
|
||||
stack_layers = n_layers - 1
|
||||
|
||||
stride_layers = min(min(down, stretch) , stack_layers)
|
||||
dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
|
||||
final_stride = 2 ** (max(down - stride_layers, 0))
|
||||
|
||||
final_dilation = 1
|
||||
if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
|
||||
final_dilation = 2
|
||||
|
||||
strides, dilations, paddings = [], [], []
|
||||
processed_layers = 0
|
||||
current_dilation = 1
|
||||
|
||||
for _ in range(stride_layers):
|
||||
# increase receptive field and downsample via stride = 2
|
||||
strides.append(2)
|
||||
dilations.append(1)
|
||||
paddings.append(1)
|
||||
processed_layers += 1
|
||||
|
||||
if processed_layers < stack_layers:
|
||||
strides.append(1)
|
||||
dilations.append(1)
|
||||
paddings.append(1)
|
||||
processed_layers += 1
|
||||
|
||||
for _ in range(dilation_layers):
|
||||
# increase receptive field via dilation = 2
|
||||
strides.append(1)
|
||||
current_dilation *= 2
|
||||
dilations.append(current_dilation)
|
||||
paddings.append(current_dilation)
|
||||
processed_layers += 1
|
||||
|
||||
while processed_layers < n_layers - 1:
|
||||
# fill up with std layers
|
||||
strides.append(1)
|
||||
dilations.append(current_dilation)
|
||||
paddings.append(current_dilation)
|
||||
processed_layers += 1
|
||||
|
||||
# final layer
|
||||
strides.append(final_stride)
|
||||
current_dilation * final_dilation
|
||||
dilations.append(current_dilation)
|
||||
paddings.append(current_dilation)
|
||||
processed_layers += 1
|
||||
|
||||
assert processed_layers == n_layers
|
||||
|
||||
return strides, dilations, paddings
|
||||
|
||||
t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
|
||||
f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)
|
||||
|
||||
plan = []
|
||||
|
||||
for i in range(num_layers):
|
||||
plan.append([
|
||||
(f_strides[i], t_strides[i]),
|
||||
(f_dilations[i], t_dilations[i]),
|
||||
(f_paddings[i], t_paddings[i]),
|
||||
])
|
||||
|
||||
return plan
|
||||
|
||||
|
||||
class DiscriminatorExperimental(SpecDiscriminatorBase):
|
||||
|
||||
def __init__(self,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
num_channels=16,
|
||||
max_channels=512,
|
||||
num_layers=5,
|
||||
use_spectral_norm=False):
|
||||
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.num_channels_max = max_channels
|
||||
self.num_layers = num_layers
|
||||
|
||||
layers = []
|
||||
stride = (2, 1)
|
||||
padding= (1, 1)
|
||||
in_channels = 1 + 2
|
||||
out_channels = self.num_channels
|
||||
for _ in range(self.num_layers):
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
)
|
||||
in_channels = out_channels + 2
|
||||
out_channels = min(2 * out_channels, self.num_channels_max)
|
||||
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
# bias biases
|
||||
bias_val = 0.1
|
||||
with torch.no_grad():
|
||||
for name, weight in self.named_parameters():
|
||||
if 'bias' in name:
|
||||
weight = weight + bias_val
|
||||
|
||||
|
||||
configs = {
|
||||
'f_down': {
|
||||
'stretch' : {
|
||||
64 : (0, 0),
|
||||
128: (1, 0),
|
||||
256: (2, 0),
|
||||
512: (3, 0),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
},
|
||||
'down' : {
|
||||
64 : (0, 0),
|
||||
128: (1, 0),
|
||||
256: (2, 0),
|
||||
512: (3, 0),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
}
|
||||
},
|
||||
'ft_down': {
|
||||
'stretch' : {
|
||||
64 : (0, 4),
|
||||
128: (1, 3),
|
||||
256: (2, 2),
|
||||
512: (3, 1),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
},
|
||||
'down' : {
|
||||
64 : (0, 4),
|
||||
128: (1, 3),
|
||||
256: (2, 2),
|
||||
512: (3, 1),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
}
|
||||
},
|
||||
'dilated': {
|
||||
'stretch' : {
|
||||
64 : (0, 4),
|
||||
128: (1, 3),
|
||||
256: (2, 2),
|
||||
512: (3, 1),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
},
|
||||
'down' : {
|
||||
64 : (0, 0),
|
||||
128: (0, 0),
|
||||
256: (0, 0),
|
||||
512: (0, 0),
|
||||
1024: (0, 0),
|
||||
2048: (0, 0)
|
||||
}
|
||||
},
|
||||
'mixed': {
|
||||
'stretch' : {
|
||||
64 : (0, 4),
|
||||
128: (1, 3),
|
||||
256: (2, 2),
|
||||
512: (3, 1),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
},
|
||||
'down' : {
|
||||
64 : (0, 0),
|
||||
128: (1, 0),
|
||||
256: (2, 0),
|
||||
512: (3, 0),
|
||||
1024: (4, 0),
|
||||
2048: (5, 0)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DiscriminatorMagFree(SpecDiscriminatorBase):
|
||||
|
||||
def __init__(self,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
num_channels=16,
|
||||
max_channels=256,
|
||||
num_layers=5,
|
||||
use_spectral_norm=False,
|
||||
design=None):
|
||||
|
||||
if design is None:
|
||||
raise ValueError('error: arch required in DiscriminatorMagFree')
|
||||
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
|
||||
stretch = configs[design]['stretch'][resolution[0]]
|
||||
down = configs[design]['down'][resolution[0]]
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.num_channels_max = max_channels
|
||||
self.num_layers = num_layers
|
||||
self.stretch = stretch
|
||||
self.down = down
|
||||
|
||||
layers = []
|
||||
plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
|
||||
in_channels = 1 + 2
|
||||
out_channels = self.num_channels
|
||||
for i in range(self.num_layers):
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
)
|
||||
in_channels = out_channels + 2
|
||||
# product over strides
|
||||
channel_factor = plan[i][0][0] * plan[i][0][1]
|
||||
out_channels = min(channel_factor * out_channels, self.num_channels_max)
|
||||
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
# for layer in layers:
|
||||
# print(layer)
|
||||
|
||||
# print("end\n\n")
|
||||
|
||||
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
# bias biases
|
||||
bias_val = 0.1
|
||||
with torch.no_grad():
|
||||
for name, weight in self.named_parameters():
|
||||
if 'bias' in name:
|
||||
weight = weight + bias_val
|
||||
|
||||
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
|
||||
|
||||
def __init__(self,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
num_channels=16,
|
||||
max_channels=512,
|
||||
num_layers=5,
|
||||
use_spectral_norm=False):
|
||||
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.num_channels_max = max_channels
|
||||
self.num_layers = num_layers
|
||||
|
||||
layers = []
|
||||
stride = (2, 1)
|
||||
padding= (1, 1)
|
||||
in_channels = 1 + 2
|
||||
out_channels = self.num_channels
|
||||
for _ in range(self.num_layers):
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
)
|
||||
in_channels = out_channels + 2
|
||||
out_channels = min(2 * out_channels, self.num_channels_max)
|
||||
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
FrequencyPositionalEmbedding(),
|
||||
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
|
||||
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
|
||||
|
||||
def __init__(self,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
num_channels=16,
|
||||
max_channels=512,
|
||||
num_layers=5,
|
||||
d=5,
|
||||
use_spectral_norm=False):
|
||||
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.resolution = resolution
|
||||
self.num_channels = num_channels
|
||||
self.num_channels_max = max_channels
|
||||
self.num_layers = num_layers
|
||||
self.d = d
|
||||
embedding_dim = 4 * d
|
||||
|
||||
|
||||
layers = []
|
||||
stride = (2, 2)
|
||||
padding= (1, 1)
|
||||
in_channels = 1 + embedding_dim
|
||||
out_channels = self.num_channels
|
||||
for _ in range(self.num_layers):
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
PositionalEmbedding2D(d),
|
||||
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
)
|
||||
in_channels = out_channels + embedding_dim
|
||||
out_channels = min(2 * out_channels, self.num_channels_max)
|
||||
|
||||
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
PositionalEmbedding2D(),
|
||||
norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
|
||||
class DiscriminatorMag(SpecDiscriminatorBase):
|
||||
def __init__(self,
|
||||
resolution,
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
num_channels=32,
|
||||
num_layers=5,
|
||||
use_spectral_norm=False):
|
||||
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.num_layers = num_layers
|
||||
|
||||
layers = []
|
||||
stride = (1, 1)
|
||||
padding= (1, 1)
|
||||
in_channels = 1
|
||||
out_channels = self.num_channels
|
||||
for _ in range(self.num_layers):
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
)
|
||||
in_channels = out_channels
|
||||
|
||||
layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)))
|
||||
|
||||
super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
discriminators = {
|
||||
'mag': DiscriminatorMag,
|
||||
'freqpos': DiscriminatorMagFreqPosition,
|
||||
'2dpos': DiscriminatorMag2dPositional,
|
||||
'experimental': DiscriminatorExperimental,
|
||||
'free': DiscriminatorMagFree
|
||||
}
|
||||
|
||||
class TFDMultiResolutionDiscriminator(torch.nn.Module):
|
||||
def __init__(self,
|
||||
fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
|
||||
architecture='mag',
|
||||
fs=16000,
|
||||
freq_roi=[50, 7400],
|
||||
noise_gain=0,
|
||||
use_spectral_norm=False,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k]
|
||||
|
||||
resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]
|
||||
|
||||
|
||||
Disc = discriminators[architecture]
|
||||
|
||||
discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))]
|
||||
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y):
|
||||
outputs = []
|
||||
|
||||
for disc in self.discriminators:
|
||||
outputs.append(disc(y))
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FWGAN_disc_wrapper(nn.Module):
|
||||
def __init__(self, disc):
|
||||
super().__init__()
|
||||
|
||||
self.disc = disc
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
|
||||
out_real = self.disc(y)
|
||||
out_fake = self.disc(y_hat)
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for y_real, y_fake in zip(out_real, out_fake):
|
||||
y_d_rs.append(y_real[-1])
|
||||
y_d_gs.append(y_fake[-1])
|
||||
fmap_rs.append(y_real[:-1])
|
||||
fmap_gs.append(y_fake[:-1])
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
190
managed_components/78__esp-opus/dnn/torch/osce/models/lace.py
Normal file
190
managed_components/78__esp-opus/dnn/torch/osce/models/lace.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
|
||||
from dnntools.sparsification import create_sparsifier
|
||||
|
||||
|
||||
class LACE(NNSBase):
|
||||
""" Linear-Adaptive Coding Enhancer """
|
||||
FRAME_SIZE=80
|
||||
|
||||
def __init__(self,
|
||||
num_features=47,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=257,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
skip=91,
|
||||
comb_gain_limit_db=-6,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
numbits_range=[50, 650],
|
||||
numbits_embedding_dim=8,
|
||||
hidden_feature_dim=64,
|
||||
partial_lookahead=True,
|
||||
norm_p=2,
|
||||
softquant=False,
|
||||
sparsify=False,
|
||||
sparsification_schedule=[10000, 30000, 100],
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.numbits_range = numbits_range
|
||||
self.numbits_embedding_dim = numbits_embedding_dim
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
self.partial_lookahead = partial_lookahead
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# numbits embedding
|
||||
self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
|
||||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
if sparsify:
|
||||
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
|
||||
af_flops = self.af1.flop_count(rate)
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + comb_flops + af_flops
|
||||
|
||||
def forward(self, x, features, periods, numbits, debug=False):
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
y = self.cf1(x, cf, periods, debug=debug)
|
||||
|
||||
y = self.cf2(y, cf, periods, debug=debug)
|
||||
|
||||
y = self.af1(y, cf, debug=debug)
|
||||
|
||||
return y
|
||||
|
||||
def get_impulse_responses(self, features, periods, numbits):
|
||||
""" generates impoulse responses on frame centers (input without batch dimension) """
|
||||
|
||||
num_frames = features.size(0)
|
||||
batch_size = 32
|
||||
max_len = 2 * (self.pitch_max + self.kernel_size) + 10
|
||||
|
||||
# spread out some pulses
|
||||
x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
|
||||
for b in range(batch_size):
|
||||
x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1
|
||||
|
||||
# prepare input
|
||||
x = torch.from_numpy(x).float().to(features.device)
|
||||
features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
|
||||
periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
|
||||
numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)
|
||||
|
||||
# run network
|
||||
with torch.no_grad():
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
y = self.cf1(x, cf, periods, debug=False)
|
||||
y = self.cf2(y, cf, periods, debug=False)
|
||||
y = self.af1(y, cf, debug=False)
|
||||
|
||||
# collect responses
|
||||
y = y.detach().squeeze().cpu().numpy()
|
||||
cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
|
||||
num_responses = num_frames - cut_frames
|
||||
responses = np.zeros((num_responses, max_len))
|
||||
|
||||
for i in range(num_responses):
|
||||
b = i % batch_size
|
||||
start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
|
||||
stop = start + max_len
|
||||
|
||||
responses[i, :] = y[b, start:stop]
|
||||
|
||||
return responses
|
||||
274
managed_components/78__esp-opus/dnn/torch/osce/models/lavoce.py
Normal file
274
managed_components/78__esp-opus/dnn/torch/osce/models/lavoce.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.noise_shaper import NoiseShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.lpcnet_feature_net import LPCNetFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
def print_channels(y, prefix="", name="", rate=16000):
|
||||
num_channels = y.size(1)
|
||||
for i in range(num_channels):
|
||||
channel_name = f"{prefix}_c{i:02d}"
|
||||
if len(name) > 0: channel_name += "_" + name
|
||||
ch = y[0,i,:].detach().cpu().numpy()
|
||||
ch = ((2**14) * ch / np.max(ch)).astype(np.int16)
|
||||
write_data(channel_name, ch, rate)
|
||||
|
||||
|
||||
|
||||
class LaVoce(nn.Module):
|
||||
""" Linear-Adaptive VOCodEr """
|
||||
FEATURE_FRAME_SIZE=160
|
||||
FRAME_SIZE=80
|
||||
|
||||
def __init__(self,
|
||||
num_features=20,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=300,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
comb_gain_limit_db=-6,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
norm_p=2,
|
||||
avg_pool_k=4,
|
||||
pulses=False,
|
||||
innovate1=True,
|
||||
innovate2=False,
|
||||
innovate3=False,
|
||||
ftrans_k=2):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.pulses = pulses
|
||||
self.ftrans_k = ftrans_k
|
||||
|
||||
assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
|
||||
self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# feature net
|
||||
self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
|
||||
|
||||
# noise shaper
|
||||
self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
|
||||
self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# non-linear transforms
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
|
||||
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)
|
||||
|
||||
# combinators
|
||||
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# feature transforms
|
||||
self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
|
||||
self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
|
||||
self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
|
||||
self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
|
||||
self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
|
||||
|
||||
|
||||
def create_phase_signals(self, periods):
|
||||
|
||||
batch_size = periods.size(0)
|
||||
progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
|
||||
progression = torch.repeat_interleave(progression, batch_size, 0)
|
||||
|
||||
phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
|
||||
chunks = []
|
||||
for sframe in range(periods.size(1)):
|
||||
f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
|
||||
|
||||
if self.pulses:
|
||||
alpha = torch.cos(f).view(batch_size, 1, 1)
|
||||
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
|
||||
pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
|
||||
|
||||
chunk = torch.cat((pulse_a, pulse_b), dim = 1)
|
||||
else:
|
||||
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
|
||||
chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
|
||||
|
||||
phase0 = phase0 + self.FRAME_SIZE * f
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
phase_signals = torch.cat(chunks, dim=-1)
|
||||
|
||||
return phase_signals
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
|
||||
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
|
||||
feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
|
||||
+ _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + comb_flops + af_flops + feature_flops
|
||||
|
||||
def feature_transform(self, f, layer):
|
||||
f = f.permute(0, 2, 1)
|
||||
f = F.pad(f, [self.ftrans_k - 1, 0])
|
||||
f = torch.tanh(layer(f))
|
||||
return f.permute(0, 2, 1)
|
||||
|
||||
def forward(self, features, periods, debug=False):
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
# upsample periods
|
||||
periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
|
||||
|
||||
# pre-net
|
||||
ref_phase = torch.tanh(self.create_phase_signals(periods))
|
||||
if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
|
||||
x = self.af_prescale(ref_phase, cf)
|
||||
noise = self.noise_shaper(cf)
|
||||
if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
|
||||
y = self.af_mix(torch.cat((x, noise), dim=1), cf)
|
||||
if debug: print_channels(y, prefix="lavoce_03", name="postselect1")
|
||||
|
||||
# temporal shaping + innovating
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape1(y[:, 1:2, :], cf)
|
||||
if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af2(y, cf, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
|
||||
cf = self.feature_transform(cf, self.post_af2)
|
||||
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape2(y[:, 1:2, :], cf)
|
||||
if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af3(y, cf, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
|
||||
cf = self.feature_transform(cf, self.post_af3)
|
||||
|
||||
# spectral shaping
|
||||
y = self.cf1(y, cf, periods, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
|
||||
cf = self.feature_transform(cf, self.post_cf1)
|
||||
|
||||
y = self.cf2(y, cf, periods, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
|
||||
cf = self.feature_transform(cf, self.post_cf2)
|
||||
|
||||
y = self.af1(y, cf, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
|
||||
cf = self.feature_transform(cf, self.post_af1)
|
||||
|
||||
# final temporal env adjustment
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape3(y[:, 1:2, :], cf)
|
||||
if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af4(y, cf, debug=debug)
|
||||
if debug: print_channels(y, prefix="lavoce_12", name="postmix2")
|
||||
|
||||
return y
|
||||
|
||||
def process(self, features, periods, debug=False):
|
||||
|
||||
self.eval()
|
||||
device = next(iter(self.parameters())).device
|
||||
with torch.no_grad():
|
||||
|
||||
# run model
|
||||
f = features.unsqueeze(0).to(device)
|
||||
p = periods.unsqueeze(0).to(device)
|
||||
|
||||
y = self.forward(f, p, debug=debug).squeeze()
|
||||
|
||||
# deemphasis
|
||||
if self.preemph > 0:
|
||||
for i in range(len(y) - 1):
|
||||
y[i + 1] += self.preemph * y[i]
|
||||
|
||||
# clip to valid range
|
||||
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.noise_shaper import NoiseShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.lpcnet_feature_net import LPCNetFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
class LaVoce400(nn.Module):
|
||||
""" Linear-Adaptive VOCodEr """
|
||||
FEATURE_FRAME_SIZE=160
|
||||
FRAME_SIZE=40
|
||||
|
||||
def __init__(self,
|
||||
num_features=20,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=300,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
comb_gain_limit_db=-6,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
norm_p=2,
|
||||
avg_pool_k=4,
|
||||
pulses=False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.pulses = pulses
|
||||
|
||||
assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
|
||||
self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# feature net
|
||||
self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
|
||||
|
||||
# noise shaper
|
||||
self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
|
||||
self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# non-linear transforms
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
|
||||
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
|
||||
|
||||
# combinators
|
||||
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
# feature transforms
|
||||
self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
|
||||
|
||||
|
||||
def create_phase_signals(self, periods):
|
||||
|
||||
batch_size = periods.size(0)
|
||||
progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
|
||||
progression = torch.repeat_interleave(progression, batch_size, 0)
|
||||
|
||||
phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
|
||||
chunks = []
|
||||
for sframe in range(periods.size(1)):
|
||||
f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
|
||||
|
||||
if self.pulses:
|
||||
alpha = torch.cos(f).view(batch_size, 1, 1)
|
||||
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
|
||||
pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
|
||||
|
||||
chunk = torch.cat((pulse_a, pulse_b), dim = 1)
|
||||
else:
|
||||
chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
|
||||
|
||||
chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
|
||||
|
||||
phase0 = phase0 + self.FRAME_SIZE * f
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
phase_signals = torch.cat(chunks, dim=-1)
|
||||
|
||||
return phase_signals
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
|
||||
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
|
||||
feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
|
||||
+ _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + comb_flops + af_flops + feature_flops
|
||||
|
||||
def feature_transform(self, f, layer):
|
||||
f = f.permute(0, 2, 1)
|
||||
f = F.pad(f, [1, 0])
|
||||
f = torch.tanh(layer(f))
|
||||
return f.permute(0, 2, 1)
|
||||
|
||||
def forward(self, features, periods, debug=False):
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
# upsample periods
|
||||
periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
|
||||
|
||||
# pre-net
|
||||
ref_phase = torch.tanh(self.create_phase_signals(periods))
|
||||
x = self.af_prescale(ref_phase, cf)
|
||||
noise = self.noise_shaper(cf)
|
||||
y = self.af_mix(torch.cat((x, noise), dim=1), cf)
|
||||
|
||||
if debug:
|
||||
ch0 = y[0,0,:].detach().cpu().numpy()
|
||||
ch1 = y[0,1,:].detach().cpu().numpy()
|
||||
ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
|
||||
ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
|
||||
write_data('prior_channel0', ch0, 16000)
|
||||
write_data('prior_channel1', ch1, 16000)
|
||||
|
||||
# temporal shaping + innovating
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape1(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af2(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af2)
|
||||
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape2(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af3(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af3)
|
||||
|
||||
# spectral shaping
|
||||
y = self.cf1(y, cf, periods, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_cf1)
|
||||
|
||||
y = self.cf2(y, cf, periods, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_cf2)
|
||||
|
||||
y = self.af1(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af1)
|
||||
|
||||
# final temporal env adjustment
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape3(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af4(y, cf, debug=debug)
|
||||
|
||||
return y
|
||||
|
||||
def process(self, features, periods, debug=False):
|
||||
|
||||
self.eval()
|
||||
device = next(iter(self.parameters())).device
|
||||
with torch.no_grad():
|
||||
|
||||
# run model
|
||||
f = features.unsqueeze(0).to(device)
|
||||
p = periods.unsqueeze(0).to(device)
|
||||
|
||||
y = self.forward(f, p, debug=debug).squeeze()
|
||||
|
||||
# deemphasis
|
||||
if self.preemph > 0:
|
||||
for i in range(len(y) - 1):
|
||||
y[i + 1] += self.preemph * y[i]
|
||||
|
||||
# clip to valid range
|
||||
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class LPCNetFeatureNet(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
feature_dim=84,
|
||||
num_channels=256,
|
||||
upsamp_factor=2,
|
||||
lookahead=True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.upsamp_factor = upsamp_factor
|
||||
self.lookahead = lookahead
|
||||
|
||||
self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
|
||||
self.conv2 = nn.Conv1d(num_channels, num_channels, 3)
|
||||
|
||||
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
|
||||
|
||||
self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)
|
||||
|
||||
def flop_count(self, rate=100):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2, self.tconv:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def forward(self, features, state=None):
|
||||
""" features shape: (batch_size, num_frames, feature_dim) """
|
||||
|
||||
batch_size = features.size(0)
|
||||
|
||||
if state is None:
|
||||
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
|
||||
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
if self.lookahead:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
|
||||
else:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [2, 0])))
|
||||
|
||||
c = torch.tanh(self.tconv(c))
|
||||
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
||||
return c
|
||||
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class NNSBase(nn.Module):
|
||||
|
||||
def __init__(self, skip=91, preemph=0.85):
|
||||
super().__init__()
|
||||
|
||||
self.skip = skip
|
||||
self.preemph = preemph
|
||||
|
||||
def process(self, sig, features, periods, numbits, debug=False):
|
||||
|
||||
self.eval()
|
||||
has_numbits = 'numbits' in self.forward.__code__.co_varnames
|
||||
device = next(iter(self.parameters())).device
|
||||
with torch.no_grad():
|
||||
|
||||
# run model
|
||||
x = sig.view(1, 1, -1).to(device)
|
||||
f = features.unsqueeze(0).to(device)
|
||||
p = periods.unsqueeze(0).to(device)
|
||||
n = numbits.unsqueeze(0).to(device)
|
||||
|
||||
if has_numbits:
|
||||
y = self.forward(x, f, p, n, debug=debug).squeeze()
|
||||
else:
|
||||
y = self.forward(x, f, p, debug=debug).squeeze()
|
||||
|
||||
# deemphasis
|
||||
if self.preemph > 0:
|
||||
for i in range(len(y) - 1):
|
||||
y[i + 1] += self.preemph * y[i]
|
||||
|
||||
# delay compensation
|
||||
y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
|
||||
out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
|
||||
|
||||
return out
|
||||
218
managed_components/78__esp-opus/dnn/torch/osce/models/no_lace.py
Normal file
218
managed_components/78__esp-opus/dnn/torch/osce/models/no_lace.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
from dnntools.quantization import soft_quant
|
||||
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
|
||||
|
||||
class NoLACE(NNSBase):
|
||||
""" Non-Linear Adaptive Coding Enhancer """
|
||||
FRAME_SIZE=80
|
||||
|
||||
def __init__(self,
|
||||
num_features=47,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=257,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
skip=91,
|
||||
comb_gain_limit_db=-6,
|
||||
global_gain_limits_db=[-6, 6],
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
numbits_range=[50, 650],
|
||||
numbits_embedding_dim=8,
|
||||
hidden_feature_dim=64,
|
||||
partial_lookahead=True,
|
||||
norm_p=2,
|
||||
avg_pool_k=4,
|
||||
pool_after=False,
|
||||
softquant=False,
|
||||
sparsify=False,
|
||||
sparsification_schedule=[100, 1000, 100],
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.numbits_range = numbits_range
|
||||
self.numbits_embedding_dim = numbits_embedding_dim
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
self.partial_lookahead = partial_lookahead
|
||||
|
||||
if isinstance(sparsification_density, numbers.Number):
|
||||
sparsification_density = 10 * [sparsification_density]
|
||||
|
||||
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# numbits embedding
|
||||
self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
|
||||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# comb filters
|
||||
left_pad = self.kernel_size // 2
|
||||
right_pad = self.kernel_size - 1 - left_pad
|
||||
self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# spectral shaping
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# non-linear transforms
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# combinators
|
||||
self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
|
||||
|
||||
# feature transforms
|
||||
self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
|
||||
|
||||
if softquant:
|
||||
self.post_cf1 = soft_quant(self.post_cf1)
|
||||
self.post_cf2 = soft_quant(self.post_cf2)
|
||||
self.post_af1 = soft_quant(self.post_af1)
|
||||
self.post_af2 = soft_quant(self.post_af2)
|
||||
self.post_af3 = soft_quant(self.post_af3)
|
||||
|
||||
|
||||
if sparsify:
|
||||
mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
|
||||
mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
|
||||
mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
|
||||
mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
|
||||
mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
|
||||
|
||||
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
|
||||
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
|
||||
shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
|
||||
feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
|
||||
+ _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"comb filters: {comb_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
|
||||
|
||||
def feature_transform(self, f, layer):
|
||||
f0 = f.permute(0, 2, 1)
|
||||
f = F.pad(f0, [1, 0])
|
||||
f = torch.tanh(layer(f))
|
||||
return f.permute(0, 2, 1)
|
||||
|
||||
def forward(self, x, features, periods, numbits, debug=False):
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
y = self.cf1(x, cf, periods, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_cf1)
|
||||
|
||||
y = self.cf2(y, cf, periods, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_cf2)
|
||||
|
||||
y = self.af1(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af1)
|
||||
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape1(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af2(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af2)
|
||||
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape2(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af3(y, cf, debug=debug)
|
||||
cf = self.feature_transform(cf, self.post_af3)
|
||||
|
||||
y1 = y[:, 0:1, :]
|
||||
y2 = self.tdshape3(y[:, 1:2, :], cf)
|
||||
y = torch.cat((y1, y2), dim=1)
|
||||
y = self.af4(y, cf, debug=debug)
|
||||
|
||||
return y
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ScaleEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
min_val,
|
||||
max_val,
|
||||
logscale=False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if min_val >= max_val:
|
||||
raise ValueError('min_val must be smaller than max_val')
|
||||
|
||||
if min_val <= 0 and logscale:
|
||||
raise ValueError('min_val must be positive when logscale is true')
|
||||
|
||||
self.dim = dim
|
||||
self.logscale = logscale
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
|
||||
if logscale:
|
||||
self.min_val = m.log(self.min_val)
|
||||
self.max_val = m.log(self.max_val)
|
||||
|
||||
|
||||
self.offset = (self.min_val + self.max_val) / 2
|
||||
self.scale_factors = nn.Parameter(
|
||||
torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.logscale: x = torch.log(x)
|
||||
x = torch.clip(x, self.min_val, self.max_val) - self.offset
|
||||
return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5)
|
||||
@@ -0,0 +1,179 @@
|
||||
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.silk_upsampler import SilkUpsampler
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.deemph import Deemph
|
||||
from utils.misc import freeze_model
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
|
||||
|
||||
class ShapeUp48(NNSBase):
|
||||
FRAME_SIZE16k=80
|
||||
|
||||
def __init__(self,
|
||||
num_features=47,
|
||||
pitch_embedding_dim=64,
|
||||
cond_dim=256,
|
||||
pitch_max=257,
|
||||
kernel_size=15,
|
||||
preemph=0.85,
|
||||
skip=288,
|
||||
conv_gain_limits_db=[-6, 6],
|
||||
numbits_range=[50, 650],
|
||||
numbits_embedding_dim=8,
|
||||
hidden_feature_dim=64,
|
||||
partial_lookahead=True,
|
||||
norm_p=2,
|
||||
target_fs=48000,
|
||||
noise_amplitude=0,
|
||||
prenet=None,
|
||||
avg_pool_k=4):
|
||||
|
||||
super().__init__(skip=skip, preemph=preemph)
|
||||
|
||||
|
||||
self.num_features = num_features
|
||||
self.cond_dim = cond_dim
|
||||
self.pitch_max = pitch_max
|
||||
self.pitch_embedding_dim = pitch_embedding_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.numbits_range = numbits_range
|
||||
self.numbits_embedding_dim = numbits_embedding_dim
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
self.partial_lookahead = partial_lookahead
|
||||
self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
|
||||
self.frame_size32 = self.FRAME_SIZE16k * 2
|
||||
self.noise_amplitude = noise_amplitude
|
||||
self.prenet = prenet
|
||||
|
||||
# freeze prenet if given
|
||||
if prenet is not None:
|
||||
freeze_model(self.prenet)
|
||||
try:
|
||||
self.deemph = Deemph(prenet.preemph)
|
||||
except:
|
||||
print("[warning] prenet model is expected to have preemph attribute")
|
||||
self.deemph = Deemph(0)
|
||||
|
||||
|
||||
|
||||
# upsampler
|
||||
self.upsampler = SilkUpsampler()
|
||||
|
||||
# pitch embedding
|
||||
self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
|
||||
|
||||
# numbits embedding
|
||||
self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
|
||||
|
||||
# feature net
|
||||
if partial_lookahead:
|
||||
self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
|
||||
else:
|
||||
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
|
||||
|
||||
# non-linear transforms
|
||||
self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
|
||||
self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)
|
||||
|
||||
# spectral shaping
|
||||
self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
|
||||
self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
|
||||
|
||||
|
||||
def flop_count(self, rate=16000, verbose=False):
|
||||
|
||||
frame_rate = rate / self.FRAME_SIZE16k
|
||||
|
||||
# feature net
|
||||
feature_net_flops = self.feature_net.flop_count(frame_rate)
|
||||
af_flops = self.af1.flop_count(rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)
|
||||
|
||||
if verbose:
|
||||
print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
|
||||
print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
|
||||
|
||||
return feature_net_flops + af_flops
|
||||
|
||||
def forward(self, x, features, periods, numbits, debug=False):
|
||||
|
||||
if self.prenet is not None:
|
||||
with torch.no_grad():
|
||||
x = self.prenet(x, features, periods, numbits)
|
||||
x = self.deemph(x)
|
||||
|
||||
|
||||
|
||||
periods = periods.squeeze(-1)
|
||||
pitch_embedding = self.pitch_embedding(periods)
|
||||
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
|
||||
|
||||
full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
|
||||
cf = self.feature_net(full_features)
|
||||
|
||||
y32 = self.upsampler.hq_2x_up(x)
|
||||
|
||||
noise = self.noise_amplitude * torch.randn_like(y32)
|
||||
noise = self.af_noise(noise, cf)
|
||||
|
||||
y32 = self.af1(y32, cf, debug=debug)
|
||||
|
||||
y32_1 = y32[:, 0:1, :]
|
||||
y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
|
||||
y32 = torch.cat((y32_1, y32_2, noise), dim=1)
|
||||
|
||||
y32 = self.af2(y32, cf, debug=debug)
|
||||
|
||||
y48 = self.upsampler.interpolate_3_2(y32)
|
||||
|
||||
y48_1 = y48[:, 0:1, :]
|
||||
y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
|
||||
y48 = torch.cat((y48_1, y48_2), dim=1)
|
||||
|
||||
y48 = self.af3(y48, cf, debug=debug)
|
||||
|
||||
return y48
|
||||
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class SilkFeatureNet(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
feature_dim=47,
|
||||
num_channels=256,
|
||||
lookahead=False):
|
||||
|
||||
super(SilkFeatureNet, self).__init__()
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.lookahead = lookahead
|
||||
|
||||
self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
|
||||
self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
|
||||
|
||||
self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
|
||||
|
||||
def flop_count(self, rate=200):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def forward(self, features, state=None):
|
||||
""" features shape: (batch_size, num_frames, feature_dim) """
|
||||
|
||||
batch_size = features.size(0)
|
||||
|
||||
if state is None:
|
||||
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
|
||||
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
if self.lookahead:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [1, 1])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [2, 2])))
|
||||
else:
|
||||
c = torch.tanh(self.conv1(F.pad(features, [2, 0])))
|
||||
c = torch.tanh(self.conv2(F.pad(c, [4, 0])))
|
||||
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
||||
return c
|
||||
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
import numbers
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
from dnntools.quantization.softquant import soft_quant
|
||||
from dnntools.sparsification import mark_for_sparsification
|
||||
|
||||
class SilkFeatureNetPL(nn.Module):
|
||||
""" feature net with partial lookahead """
|
||||
def __init__(self,
|
||||
feature_dim=47,
|
||||
num_channels=256,
|
||||
hidden_feature_dim=64,
|
||||
softquant=False,
|
||||
sparsify=True,
|
||||
sparsification_density=0.5,
|
||||
apply_weight_norm=False):
|
||||
|
||||
super(SilkFeatureNetPL, self).__init__()
|
||||
|
||||
if isinstance(sparsification_density, numbers.Number):
|
||||
sparsification_density = 4 * [sparsification_density]
|
||||
|
||||
self.feature_dim = feature_dim
|
||||
self.num_channels = num_channels
|
||||
self.hidden_feature_dim = hidden_feature_dim
|
||||
|
||||
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
|
||||
|
||||
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
|
||||
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
|
||||
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
|
||||
self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
|
||||
|
||||
if softquant:
|
||||
self.conv2 = soft_quant(self.conv2)
|
||||
self.tconv = soft_quant(self.tconv)
|
||||
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
|
||||
|
||||
|
||||
if sparsify:
|
||||
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
|
||||
mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
|
||||
mark_for_sparsification(
|
||||
self.gru,
|
||||
{
|
||||
'W_ir' : (sparsification_density[2], [8, 4], False),
|
||||
'W_iz' : (sparsification_density[2], [8, 4], False),
|
||||
'W_in' : (sparsification_density[2], [8, 4], False),
|
||||
'W_hr' : (sparsification_density[3], [8, 4], True),
|
||||
'W_hz' : (sparsification_density[3], [8, 4], True),
|
||||
'W_hn' : (sparsification_density[3], [8, 4], True),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def flop_count(self, rate=200):
|
||||
count = 0
|
||||
for conv in self.conv1, self.conv2, self.tconv:
|
||||
count += _conv1d_flop_count(conv, rate)
|
||||
|
||||
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def forward(self, features, state=None):
|
||||
""" features shape: (batch_size, num_frames, feature_dim) """
|
||||
|
||||
batch_size = features.size(0)
|
||||
num_frames = features.size(1)
|
||||
|
||||
if state is None:
|
||||
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
|
||||
|
||||
features = features.permute(0, 2, 1)
|
||||
# dimensionality reduction
|
||||
c = torch.tanh(self.conv1(features))
|
||||
|
||||
# frame accumulation
|
||||
c = c.permute(0, 2, 1)
|
||||
c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
|
||||
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
|
||||
|
||||
# upsampling
|
||||
c = torch.tanh(self.tconv(c))
|
||||
c = c.permute(0, 2, 1)
|
||||
|
||||
c, _ = self.gru(c, state)
|
||||
|
||||
return c
|
||||
Reference in New Issue
Block a user