add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from .lace import LACE
from .no_lace import NoLACE
from .lavoce import LaVoce
from .lavoce_400 import LaVoce400
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
# Registry mapping model-name strings (as used in configs / CLI) to their classes.
model_dict = dict(
    lace=LACE,
    nolace=NoLACE,
    lavoce=LaVoce,
    lavoce400=LaVoce400,
    fdmresdisc=FDMResDisc,
)

View File

@@ -0,0 +1,974 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio
from utils.spec import gen_filterbank
# auxiliary functions
def remove_all_weight_norms(module):
    """Strip weight normalization from every submodule of *module* that has it.

    A submodule is considered weight-normalized if it carries the ``weight_v``
    attribute that ``torch.nn.utils.weight_norm`` installs.

    Note: the loop variable is deliberately not named ``m`` to avoid shadowing
    the module-level alias ``import math as m`` used elsewhere in this file.
    """
    for submodule in module.modules():
        if hasattr(submodule, 'weight_v'):
            nn.utils.remove_weight_norm(submodule)
def create_smoothing_kernel(h, w, gamma=1.5):
    """Build an h-by-w Gaussian-like smoothing window, normalized to sum to 1.

    The window is centered on the kernel and gamma controls its width relative
    to the kernel extent (larger gamma -> flatter window).
    """
    center_h = h / 2 - 0.5
    center_w = w / 2 - 0.5
    scale_h = gamma * center_h
    scale_w = gamma * center_w

    dist_h = ((torch.arange(h) - center_h) / scale_h) ** 2
    dist_w = ((torch.arange(w) - center_w) / scale_w) ** 2

    # outer sum of the squared per-axis distances
    grid = dist_h.view(-1, 1) + dist_w.view(1, -1)
    window = torch.exp(-grid)

    return window / window.sum()
def create_kernel(h, w, sh, sw):
    """Create an (h, w) analysis kernel for stride (sh, sw).

    Starts from a constant sh-by-sw proto kernel (a disjoint partition of 1
    under the given strides) and smooths it with the Gaussian-like window from
    create_smoothing_kernel, so overlapping shifted copies still sum to 1.
    """
    base = torch.ones((sh, sw))

    # smoothing window spans the remaining extent of the target kernel
    eta_h = h - sh + 1
    eta_w = w - sw + 1
    assert eta_h > 0 and eta_w > 0
    eta = create_smoothing_kernel(eta_h, eta_w).view(1, 1, eta_h, eta_w)

    # full-correlation of the proto kernel with the window via padded conv2d
    padded = F.pad(base, [eta_w - 1, eta_w - 1, eta_h - 1, eta_h - 1])
    padded = padded.unsqueeze(0).unsqueeze(0)

    return F.conv2d(padded, eta)
# positional embeddings
class FrequencyPositionalEmbedding(nn.Module):
    """Appends per-frequency-bin sin/cos position channels to a 4-d tensor.

    Input (B, C, F, T) -> output (B, C + 2, F, T): one sine and one cosine
    channel encoding the bin index over a full 2*pi cycle across the F axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        num_bins = x.size(2)
        phase = torch.arange(0, num_bins, dtype=x.dtype, device=x.device) * torch.pi * 2 / num_bins
        sin_ch = torch.sin(phase).reshape(1, 1, -1, 1)
        cos_ch = torch.cos(phase).reshape(1, 1, -1, 1)

        # broadcast the 1-channel encodings to the input's batch/time extents
        pad = torch.zeros_like(x[:, 0:1, :, :])
        return torch.cat((x, pad + sin_ch, pad + cos_ch), dim=1)
class PositionalEmbedding2D(nn.Module):
    """Appends transformer-style 2-d sinusoidal position channels.

    Encodes position along both the frequency axis (dim 2) and the time axis
    (dim 3) with d sine and d cosine channels per axis, i.e. the output has
    4*d extra channels.
    """

    def __init__(self, d=5):
        super().__init__()
        # d: number of frequency coefficients per (axis, sin/cos) pair
        self.d = d

    def forward(self, x):
        N = x.size(2)
        M = x.size(3)

        h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
        w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
        # geometric frequency ladder 10000^(-2i/d), i = 0..d-1
        coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)

        h_sin = torch.sin(coeffs * h_args)
        # Bug fix: the cosine channels were computed with torch.sin, which
        # duplicated the sine channels and dropped the phase information.
        h_cos = torch.cos(coeffs * h_args)
        w_sin = torch.sin(coeffs * w_args)
        w_cos = torch.cos(coeffs * w_args)

        # broadcast the encodings to the input's batch and spatial extents
        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)

        return y
# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
    """Base class for spectrogram-domain GAN discriminators.

    Subclasses supply the conv stack as ``layers``; this base computes a noisy
    magnitude spectrogram restricted to a frequency region of interest (ROI),
    runs it through the stack, and additionally probes the stack's receptive
    field so per-output-pixel relevance maps can be produced.
    """
    # upper bound (in STFT pixels) for the receptive-field probing below
    RECEPTIVE_FIELD_MAX_WIDTH=10000

    def __init__(self,
                 layers,                 # list of nn.Module stages; last one produces the score map
                 resolution,             # (n_fft, hop_length, win_length) for the STFT
                 fs=16000,               # sampling rate in Hz
                 freq_roi=[50, 7000],    # [f_min, f_max] in Hz kept for discrimination
                 noise_gain=1e-3,        # gain of the envelope-shaped noise floor
                 fmap_start_index=0      # first layer whose output counts as a feature map
                 ):
        super().__init__()

        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index

        if fmap_start_index >= len(layers):
            raise ValueError(f'fmap_start_index is larger than number of layers')

        # filter bank for noise shaping
        n_fft = resolution[0]

        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )

        # roi bins: translate the Hz ROI into STFT bin indices (with a small
        # tolerance so bins on the boundary are included)
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1)

        self.init_weights()

        # determine receptive field size, offsets and strides by probing the
        # stack at two neighboring output positions; double the probe size
        # until both positions see a fully interior (equal-sized) field
        hw = 1000
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]
            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]

            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)

            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1

            if h0 != h1 or w0 != w1:
                # at least one probe position touched the border; retry bigger
                hw = 2 * hw
            else:
                # strides: shift of the field between adjacent output pixels
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]

                # NOTE(review): this `continue` does not change hw, so if it
                # ever triggers the loop would spin forever — confirm it is
                # unreachable for the layer stacks used here.
                if sh == 0 or sw == 0: continue

                # offsets: field start extrapolated back to output position 0
                oh = hs0[0] - sh * pos0[0]
                ow = ws0[0] - sw * pos0[1]

                # overlap factor
                overlap = w0 / sw + h0 / sh
                #print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")

                self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}

                break

            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")

        # create transposed convolutional kernel
        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)

    def run_layer_stack(self, spec):
        """Runs a (B, F, T) spectrogram through all layers; returns every layer output."""
        output = []
        x = spec.unsqueeze(1)
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        return output

    def forward(self, x):
        """ returns array with feature maps and final score at index -1 """
        output = []
        x = self.spectrogram(x)
        output = self.run_layer_stack(x)
        return output[self.fmap_start_index:]

    def receptive_field(self, output_pos):
        """Maps an output pixel position to its (h_range, w_range) in the full spectrogram."""
        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            # start_bin added because the stack only sees the frequency ROI
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)

            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w

            return (h_min, h_max), (w_min, w_max)
        else:
            return None, None

    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field probabilistically via autograd (slow) """
        x = torch.randn((1,) + input_dims, requires_grad=True)

        # run input through layers
        y = self.run_layer_stack(x)[-1]

        b, c, h, w = y.shape

        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")

        # backprop a single output pixel; nonzero input grads mark its field
        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1

        (mask * y).sum().backward()

        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)

        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()

        return [h_min, h_max], [w_min, w_max]

    def init_weights(self):
        # NOTE(review): nn.Conv2d is absent from this check, so the 2-d conv
        # layers used by the subclasses in this file keep their default init —
        # confirm whether orthogonal init was intended for them as well.
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def spectrogram(self, x):
        """Noisy dB magnitude spectrogram of waveform x, cropped to the frequency ROI."""
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True) #[B, F, T]
        x = torch.abs(x)

        # noise floor following spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain

        x = x + noise

        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]

        return torchaudio.functional.amplitude_to_DB(x,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)#torch.sqrt(x)

    def grad_map(self, x):
        """Gradient of the LSGAN-style loss w.r.t. the (clean) dB spectrogram of x."""
        self.zero_grad()
        n_fft, hop_length, win_length = self.resolution
        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)

        specgram.requires_grad = True
        specgram.retain_grad()

        if specgram.grad is not None:
            specgram.grad.zero_()

        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]

        scores = self.run_layer_stack(y)[-1]

        loss = torch.mean((1 - scores) ** 2)
        loss.backward()

        return specgram.data[0], torch.abs(specgram.grad)[0]

    def relevance_map(self, x):
        """Upsamples output scores back onto the spectrogram grid via the probed receptive field."""
        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)

        scores = self.forward(x)[-1]

        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)

        with torch.no_grad():
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2))
            rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw]

        relevance = torch.zeros_like(specgram)
        relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv

        return specgram, relevance

    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """
        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations

        # clone of `layer` with weights transformed by g (used for the z^B input rule)
        def newconv2d(layer,g):
            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)
            try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError: pass
            try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError: pass
            return new_layer

        # empirical dB input bounds per n_fft, used as the box constraints of
        # the first-layer rule when low/high are not given
        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }

        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]

        remove_all_weight_norms(self)

        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

        num_layers = len(self.layers)
        X = self.spectrogram(x).detach()

        # forward pass
        A = [X.unsqueeze(1)] + [None] * len(self.layers)
        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])

        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)

        # keep only confidently-scored positions according to `label`
        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask

        # backward pass: basic LRP-0 rule layer by layer (activations stripped)
        R = [None] * num_layers + [r]

        for l in range(1, num_layers)[::-1]:
            A[l] = (A[l]).data.requires_grad_(True)
            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l+1] / z).data
            (z*s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data

        # first layer: z^B rule with box constraints [low, high]
        A[0] = (A[0].data).requires_grad_(True)
        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)

        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]
            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))
        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))

        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad

        R[0] = (A[0] * c + Xl * cp + Xh * cm)
        #R[0] = (A[0] * c).data

        return X, R[0].mean(dim=1)
def create_3x3_conv_plan(num_layers : int,
                         f_stretch : int,
                         f_down : int,
                         t_stretch : int,
                         t_down : int
                         ):
    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries
        [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
        (frequency component first, matching the assembly loop below)
    """

    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):
        """Per-axis plan: strided layers first, then dilated layers, then fill-up layers."""
        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch) , stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)

        # any remaining downsampling is pushed onto the final layer's stride
        final_stride = 2 ** (max(down - stride_layers, 0))

        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            # one standard layer between the strided and the dilated section
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        # Bug fix: the original `current_dilation * final_dilation` discarded
        # the product, so final_dilation was computed but never applied.
        current_dilation *= final_dilation
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan
class DiscriminatorExperimental(SpecDiscriminatorBase):
    """Frequency-positional discriminator variant with ReLU body and Sigmoid head."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding= (1, 1)
        in_channels = 1 + 2   # magnitude + sin/cos frequency-position channels
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        # final single-channel score map squashed to (0, 1)
        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases: shift every bias by a small positive offset.
        # Bug fix: the original `weight = weight + bias_val` only rebound the
        # loop variable and never modified the parameter; use in-place add_.
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight.add_(bias_val)
# STFT sizes for which per-resolution (f, t) factors are defined
_FB_RESOLUTIONS = (64, 128, 256, 512, 1024, 2048)


def _freq_only():
    # frequency-axis factor grows with resolution, no time-axis component
    return {res: (i, 0) for i, res in enumerate(_FB_RESOLUTIONS)}


def _freq_time_tapered():
    # frequency factor grows while the time factor tapers off: 4, 3, 2, 1, 0, 0
    return {res: (i, max(4 - i, 0)) for i, res in enumerate(_FB_RESOLUTIONS)}


def _none():
    return {res: (0, 0) for res in _FB_RESOLUTIONS}


# log_2 (frequency, time) stretch/downsampling factors per STFT size,
# keyed by design name; consumed by DiscriminatorMagFree.
configs = {
    'f_down': {
        'stretch': _freq_only(),
        'down': _freq_only(),
    },
    'ft_down': {
        'stretch': _freq_time_tapered(),
        'down': _freq_time_tapered(),
    },
    'dilated': {
        'stretch': _freq_time_tapered(),
        'down': _none(),
    },
    'mixed': {
        'stretch': _freq_time_tapered(),
        'down': _freq_only(),
    },
}
class DiscriminatorMagFree(SpecDiscriminatorBase):
    """Discriminator whose stride/dilation layout is derived from a named design
    (see the module-level ``configs`` table) via create_3x3_conv_plan."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('error: arch required in DiscriminatorMagFree')

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        # look up the (frequency, time) stretch/down factors for this STFT size
        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        in_channels = 1 + 2   # magnitude + sin/cos frequency-position channels
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # product over strides
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        # final single-channel score map squashed to (0, 1)
        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases: shift every bias by a small positive offset.
        # Bug fix: the original `weight = weight + bias_val` only rebound the
        # loop variable and never modified the parameter; use in-place add_.
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight.add_(bias_val)
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
    """Magnitude-spectrogram discriminator with frequency positional embeddings
    and a LeakyReLU conv body; the last layer emits a raw single-channel score map."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        stride = (2, 1)
        padding = (1, 1)

        layers = []
        in_ch = 1 + 2   # magnitude + sin/cos frequency-position channels
        out_ch = self.num_channels
        for _ in range(self.num_layers):
            block = nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_ch, out_ch, (3, 3), stride=stride, padding=padding)),
                nn.LeakyReLU(0.2, inplace=True)
            )
            layers.append(block)
            in_ch = out_ch + 2
            out_ch = min(2 * out_ch, self.num_channels_max)

        # final projection to a single-channel score map (no activation)
        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_ch, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
    """Magnitude-spectrogram discriminator with 2-d (frequency and time)
    sinusoidal positional embeddings before every conv layer."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.resolution = resolution
        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        embedding_dim = 4 * d   # d sin + d cos channels per axis

        layers = []
        stride = (2, 2)
        padding= (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels

        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)

        # Bug fix: the final embedding must use the same d as the loop layers.
        # The original used PositionalEmbedding2D() (default d=5), which breaks
        # the conv input-channel count whenever d != 5.
        layers.append(
            nn.Sequential(
                PositionalEmbedding2D(d),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag(SpecDiscriminatorBase):
    """Plain magnitude-spectrogram discriminator: constant-width conv stack,
    no positional embeddings, raw single-channel score map at the end."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        stride = (1, 1)
        padding = (1, 1)

        layers = []
        in_ch, out_ch = 1, self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    norm_f(nn.Conv2d(in_ch, out_ch, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_ch = out_ch

        # final projection to a single-channel score map (no activation)
        layers.append(norm_f(nn.Conv2d(in_ch, 1, (3, 3), padding=padding)))

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
# Registry of spectral-discriminator architectures, keyed by the names
# accepted by TFDMultiResolutionDiscriminator's `architecture` argument.
discriminators = {
    name: cls
    for name, cls in (
        ('mag', DiscriminatorMag),
        ('freqpos', DiscriminatorMagFreqPosition),
        ('2dpos', DiscriminatorMag2dPositional),
        ('experimental', DiscriminatorExperimental),
        ('free', DiscriminatorMagFree),
    )
}
class TFDMultiResolutionDiscriminator(torch.nn.Module):
    """Bank of spectral discriminators, one per STFT resolution."""

    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        # rescale the 16 kHz-referenced FFT sizes to the actual sampling rate
        fft_sizes = [int(round(size_16k * fs / 16000)) for size_16k in fft_sizes_16k]
        # resolution triple handed to each sub-discriminator: (n_fft, hop, win)
        resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]

        disc_cls = discriminators[architecture]
        self.discriminators = nn.ModuleList(
            [disc_cls(res, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain,
                      use_spectral_norm=use_spectral_norm, **kwargs)
             for res in resolutions]
        )

    def forward(self, y):
        """Runs y through every sub-discriminator; one output list per resolution."""
        return [disc(y) for disc in self.discriminators]
class FWGAN_disc_wrapper(nn.Module):
    """Adapts a multi-resolution discriminator to the FWGAN-style interface:
    (real, fake) -> (real scores, fake scores, real fmaps, fake fmaps)."""

    def __init__(self, disc):
        super().__init__()
        self.disc = disc

    def forward(self, y, y_hat):
        real_outputs = self.disc(y)
        fake_outputs = self.disc(y_hat)

        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for real_out, fake_out in zip(real_outputs, fake_outputs):
            # per resolution: last entry is the score, the rest are feature maps
            y_d_rs.append(real_out[-1])
            fmap_rs.append(real_out[:-1])
            y_d_gs.append(fake_out[-1])
            fmap_gs.append(fake_out[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

View File

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.sparsification import create_sparsifier
class LACE(NNSBase):
    """ Linear-Adaptive Coding Enhancer

    Enhances decoded speech by running SILK-style features through a feature
    net and applying two adaptive comb filters (pitch-driven) followed by one
    adaptive conv filter (spectral shaping) to the signal.
    """
    # processing frame size in samples
    FRAME_SIZE=80

    def __init__(self,
                 num_features=47,               # dimension of the raw feature vectors
                 pitch_embedding_dim=64,        # size of the learned pitch-period embedding
                 cond_dim=256,                  # conditioning dimension fed to the adaptive filters
                 pitch_max=257,                 # largest pitch period / embedding index
                 kernel_size=15,                # kernel size of comb and conv filters
                 preemph=0.85,                  # pre-emphasis coefficient (forwarded to NNSBase)
                 skip=91,                       # alignment skip in samples (forwarded to NNSBase)
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],       # value range for the numbits scale embedding
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,        # selects the partial-lookahead feature net
                 norm_p=2,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[10000, 30000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (presumably the codec's bit allocation — confirm with caller)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters (centered kernel: split padding left/right)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping (causal: all padding on the left)
        self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        if sparsify:
            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Approximate FLOP count per second of audio at the given sampling rate."""
        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhances signal x conditioned on features, pitch periods and numbits."""
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        # concatenate raw features with both embeddings, then derive conditioning
        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # two comb stages followed by one adaptive conv stage
        y = self.cf1(x, cf, periods, debug=debug)
        y = self.cf2(y, cf, periods, debug=debug)
        y = self.af1(y, cf, debug=debug)

        return y

    def get_impulse_responses(self, features, periods, numbits):
        """ generates impulse responses on frame centers (input without batch dimension) """
        num_frames = features.size(0)
        batch_size = 32
        max_len = 2 * (self.pitch_max + self.kernel_size) + 10

        # spread out some pulses: each batch row gets pulses every
        # batch_size frames, staggered by one frame per row, so every frame
        # center carries a pulse in exactly one row
        x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
        for b in range(batch_size):
            x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1

        # prepare input: replicate the conditioning across the batch
        x = torch.from_numpy(x).float().to(features.device)
        features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
        periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
        numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)

        # run network
        with torch.no_grad():
            periods = periods.squeeze(-1)
            pitch_embedding = self.pitch_embedding(periods)
            numbits_embedding = self.numbits_embedding(numbits).flatten(2)
            full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
            cf = self.feature_net(full_features)
            y = self.cf1(x, cf, periods, debug=False)
            y = self.cf2(y, cf, periods, debug=False)
            y = self.af1(y, cf, debug=False)

        # collect responses: frame i's response lives in batch row i % batch_size
        y = y.detach().squeeze().cpu().numpy()
        cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
        num_responses = num_frames - cut_frames
        responses = np.zeros((num_responses, max_len))

        for i in range(num_responses):
            b = i % batch_size
            start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
            stop = start + max_len
            responses[i, :] = y[b, start:stop]

        return responses

View File

@@ -0,0 +1,274 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data
from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
def print_channels(y, prefix="", name="", rate=16000):
    """Dump every channel of the first batch item of y as int16 debug signals.

    y      : tensor of shape (batch, channels, samples); only batch item 0 is dumped
    prefix : channel-name prefix, extended with the channel index
    name   : optional suffix appended to each channel name
    rate   : sample rate passed through to write_data
    """
    num_channels = y.size(1)
    for i in range(num_channels):
        channel_name = f"{prefix}_c{i:02d}"
        if len(name) > 0: channel_name += "_" + name
        ch = y[0, i, :].detach().cpu().numpy()
        # Normalize by the peak *magnitude* rather than np.max(ch): a negative
        # peak larger than the positive maximum would otherwise overflow the
        # int16 cast, and an all-zero channel would divide by zero.
        peak = np.max(np.abs(ch))
        scale = (2**14) / peak if peak > 0 else 0.0
        ch = (scale * ch).astype(np.int16)
        write_data(channel_name, ch, rate)
class LaVoce(nn.Module):
    """ Linear-Adaptive VOCodEr.

    Synthesizes speech from acoustic features and pitch periods: a phase/pulse
    excitation plus shaped noise is mixed and refined by linear-adaptive conv
    and comb filtering stages, with non-linear temporal shaping (TDShaper)
    between stages. All filtering layers are conditioned on the output of a
    feature net, which is re-transformed after each stage.
    """
    # acoustic features arrive every FEATURE_FRAME_SIZE samples; the filtering
    # layers run on FRAME_SIZE subframes, so conditioning is upsampled by
    # FEATURE_FRAME_SIZE // FRAME_SIZE
    FEATURE_FRAME_SIZE=160
    FRAME_SIZE=80
    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False,
                 innovate1=True,
                 innovate2=False,
                 innovate3=False,
                 ftrans_k=2):
        """Build the vocoder.

        num_features        : dimension of the acoustic feature vector
        pitch_embedding_dim : size of the learned pitch-period embedding
        cond_dim            : width of the conditioning features
        pitch_max           : largest representable pitch period (samples)
        kernel_size         : kernel size of all adaptive filtering layers
        preemph             : de-emphasis coefficient applied in process()
        pulses              : if True, excite with rectified pulse pairs
                              instead of sin/cos phase signals
        innovate1..3        : enable the innovation branch of the three TDShapers
        ftrans_k            : kernel size of the inter-stage feature transforms
        NOTE(review): the list-valued defaults are shared across instances;
        harmless as long as they are only read.
        """
        super().__init__()
        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses
        self.ftrans_k = ftrans_k
        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
        # feature net: maps features + pitch embedding to upsampled conditioning
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
        # noise shaper: produces the conditioned noise excitation channel
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
        # comb filters (pitch-driven, centered padding)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        # excitation pre-scaling (2 phase channels -> 1) and mix with noise (2 -> 2)
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # spectral shaping (1 channel in, 2 channels out)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # non-linear temporal transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)
        # combinators that merge the shaped channel back with the pass-through channel
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # causal feature transforms applied to the conditioning between stages
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)

    def create_phase_signals(self, periods):
        """Build the 2-channel excitation from per-subframe pitch periods.

        Accumulates phase across subframes so the oscillator is continuous.
        In pulse mode returns rectified positive/negative pulse trains;
        otherwise returns sin/cos channels. Output shape: (batch, 2, samples).
        """
        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)
        # running phase, carried over from one subframe to the next
        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
            if self.pulses:
                # threshold the sine at cos(f) so each period yields one narrow pulse
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
            phase0 = phase0 + self.FRAME_SIZE * f
            chunks.append(chunk)
        phase_signals = torch.cat(chunks, dim=-1)
        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPS at the given sample rate (excluding TDShaper layers)."""
        frame_rate = rate / self.FRAME_SIZE
        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        """Apply a causal conv feature transform with tanh to (batch, frames, cond_dim)."""
        f = f.permute(0, 2, 1)
        # left-pad so the kernel-ftrans_k conv is causal and length-preserving
        f = F.pad(f, [self.ftrans_k - 1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
        """Synthesize a signal from features and per-frame pitch periods.

        Returns the final 1-channel signal; debug=True dumps intermediate
        channels via print_channels.
        """
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)
        # upsample periods to subframe rate
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
        # pre-net: build excitation and mix with shaped noise
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug: print_channels(y, prefix="lavoce_03", name="postselect1")
        # temporal shaping + innovating: shape channel 1, pass channel 0 through
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
        cf = self.feature_transform(cf, self.post_af2)
        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
        cf = self.feature_transform(cf, self.post_af3)
        # spectral shaping: comb, comb, adaptive conv
        y = self.cf1(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
        cf = self.feature_transform(cf, self.post_cf1)
        y = self.cf2(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
        cf = self.feature_transform(cf, self.post_cf2)
        y = self.af1(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
        cf = self.feature_transform(cf, self.post_af1)
        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_12", name="postmix2")
        return y

    def process(self, features, periods, debug=False):
        """Run inference on a single utterance and return int16 samples."""
        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():
            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)
            y = self.forward(f, p, debug=debug).squeeze()
            # deemphasis: invert the pre-emphasis applied during training
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]
            # clip to valid int16 range before the cast
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
        return out

View File

@@ -0,0 +1,254 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data
from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
class LaVoce400(nn.Module):
    """ Linear-Adaptive VOCodEr.

    Variant of LaVoce operating on 40-sample subframes (upsampling factor 4)
    with fixed innovation settings and kernel-2 feature transforms.
    """
    # acoustic features arrive every FEATURE_FRAME_SIZE samples; filtering
    # runs on FRAME_SIZE subframes
    FEATURE_FRAME_SIZE=160
    FRAME_SIZE=40
    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False):
        """Build the vocoder; see LaVoce for parameter meanings.

        pulses : if True, excite with rectified pulse pairs instead of
                 sin/cos phase signals.
        NOTE(review): the list-valued defaults are shared across instances;
        harmless as long as they are only read.
        """
        super().__init__()
        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses
        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
        # feature net: features + pitch embedding -> upsampled conditioning
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
        # noise shaper: conditioned noise excitation channel
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
        # comb filters (pitch-driven, centered padding)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        # excitation pre-scaling (2 -> 1) and mix with noise (2 -> 2)
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # spectral shaping (1 -> 2)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # non-linear temporal transforms (first one with innovation branch)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        # combinators merging the shaped channel with the pass-through channel
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        # causal kernel-2 feature transforms applied between stages
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)

    def create_phase_signals(self, periods):
        """Build the 2-channel excitation from per-subframe pitch periods.

        Phase is accumulated across subframes for a continuous oscillator.
        Pulse mode yields rectified positive/negative pulse trains; otherwise
        sin/cos channels. Output shape: (batch, 2, samples).
        """
        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)
        # running phase carried across subframes
        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
            if self.pulses:
                # threshold the sine at cos(f) so each period yields one narrow pulse
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
            phase0 = phase0 + self.FRAME_SIZE * f
            chunks.append(chunk)
        phase_signals = torch.cat(chunks, dim=-1)
        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPS at the given sample rate (excluding TDShaper layers)."""
        frame_rate = rate / self.FRAME_SIZE
        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        """Apply a causal kernel-2 conv feature transform with tanh to (batch, frames, cond_dim)."""
        f = f.permute(0, 2, 1)
        # left-pad by one frame so the kernel-2 conv is causal and length-preserving
        f = F.pad(f, [1, 0])
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
        """Synthesize a signal from features and per-frame pitch periods."""
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)
        # upsample periods to subframe rate
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
        # pre-net: build excitation and mix with shaped noise
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug:
            ch0 = y[0,0,:].detach().cpu().numpy()
            ch1 = y[0,1,:].detach().cpu().numpy()
            # NOTE(review): scaling by np.max(...) divides by zero for an
            # all-zero channel and can overflow int16 when the negative peak
            # exceeds the positive one — confirm and consider np.abs().
            ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
            ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
            write_data('prior_channel0', ch0, 16000)
            write_data('prior_channel1', ch1, 16000)
        # temporal shaping + innovating: shape channel 1, pass channel 0 through
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)
        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)
        # spectral shaping: comb, comb, adaptive conv
        y = self.cf1(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)
        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)
        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)
        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        return y

    def process(self, features, periods, debug=False):
        """Run inference on a single utterance and return int16 samples."""
        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():
            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)
            y = self.forward(f, p, debug=debug).squeeze()
            # deemphasis: invert the pre-emphasis applied during training
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]
            # clip to valid int16 range before the cast
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
        return out

View File

@@ -0,0 +1,91 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class LPCNetFeatureNet(nn.Module):
    """Conditioning feature net: two conv layers, transposed-conv upsampling, GRU.

    Maps per-frame acoustic features of width feature_dim to conditioning
    vectors of width num_channels at upsamp_factor times the input frame rate.
    With lookahead=True the first conv sees one future frame; otherwise both
    convs are strictly causal.
    """

    def __init__(self,
                 feature_dim=84,
                 num_channels=256,
                 upsamp_factor=2,
                 lookahead=True):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.upsamp_factor = upsamp_factor
        self.lookahead = lookahead

        # kernel-3 feature convs, followed by upsampling and a recurrent layer
        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)
        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)
        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)

    def flop_count(self, rate=100):
        """Estimate FLOPS for processing `rate` frames per second."""
        conv_flops = sum(_conv1d_flop_count(layer, rate) for layer in (self.conv1, self.conv2, self.tconv))
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """ features shape: (batch_size, num_frames, feature_dim) """
        if state is None:
            # fresh zero GRU state per call unless the caller carries one over
            state = torch.zeros((1, features.size(0), self.num_channels), device=features.device)

        x = features.permute(0, 2, 1)
        # first conv: symmetric padding with lookahead, causal otherwise;
        # second conv is always causal (left padding only)
        pad_first = [1, 1] if self.lookahead else [2, 0]
        x = torch.tanh(self.conv1(F.pad(x, pad_first)))
        x = torch.tanh(self.conv2(F.pad(x, [2, 0])))
        x = torch.tanh(self.tconv(x))

        x, _ = self.gru(x.permute(0, 2, 1), state)
        return x

View File

@@ -0,0 +1,69 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class NNSBase(nn.Module):
    """Base class providing a common inference wrapper for enhancement models.

    Subclasses implement forward(x, features, periods[, numbits], ...);
    process() handles batching, de-emphasis, delay compensation and the
    conversion to int16 samples.
    """

    def __init__(self, skip=91, preemph=0.85):
        super().__init__()
        # number of leading samples dropped for delay compensation
        self.skip = skip
        # de-emphasis coefficient (inverse of training-time pre-emphasis)
        self.preemph = preemph

    def process(self, sig, features, periods, numbits, debug=False):
        """Run the model on one utterance and return int16 output samples."""
        self.eval()
        # detect by signature inspection whether forward() takes a numbits argument
        wants_numbits = 'numbits' in self.forward.__code__.co_varnames
        device = next(iter(self.parameters())).device

        with torch.no_grad():
            # batch the inputs and move them to the model device
            model_args = [
                sig.view(1, 1, -1).to(device),
                features.unsqueeze(0).to(device),
                periods.unsqueeze(0).to(device),
            ]
            if wants_numbits:
                model_args.append(numbits.unsqueeze(0).to(device))

            y = self.forward(*model_args, debug=debug).squeeze()

            # de-emphasis: first-order IIR inverting the training pre-emphasis
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # delay compensation: drop the first `skip` samples, zero-pad the tail
            y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out

View File

@@ -0,0 +1,218 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.complexity import _conv1d_flop_count
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.quantization import soft_quant
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
class NoLACE(NNSBase):
""" Non-Linear Adaptive Coding Enhancer """
FRAME_SIZE=80
    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 avg_pool_k=4,
                 pool_after=False,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[100, 1000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        """Build the NoLACE enhancer.

        num_features           : dimension of the acoustic feature vector
        pitch_embedding_dim    : size of the learned pitch-period embedding
        cond_dim               : width of the conditioning features
        pitch_max              : largest representable pitch period (samples)
        kernel_size            : kernel size of the adaptive filtering layers
        preemph / skip         : de-emphasis coefficient and delay compensation
                                 (handled by the NNSBase.process wrapper)
        numbits_range          : value range of the numbits scale embedding
        partial_lookahead      : select SilkFeatureNetPL vs SilkFeatureNet
        softquant / sparsify   : enable quantization-aware training and/or
                                 weight sparsification of selected layers
        sparsification_density : scalar (broadcast to 10 entries) or list of
                                 per-layer densities
        apply_weight_norm      : wrap weight-bearing layers in weight_norm
        NOTE(review): the list-valued defaults are shared across instances;
        harmless as long as they are only read.
        """
        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # a scalar density is broadcast to one entry per sparsifiable layer
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 10 * [sparsification_density]

        # identity wrapper unless weight norm is requested
        norm = weight_norm if apply_weight_norm else lambda x, name=None: x

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (log-scale embedding of the bitrate indicator)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net (densities 0..3 are consumed inside SilkFeatureNetPL)
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters (pitch-driven, centered padding)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping (1 channel in, 2 channels out)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # non-linear temporal transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # combinators merging the shaped channel with the pass-through channel
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # causal kernel-2 feature transforms applied between stages
        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))

        # optional quantization-aware wrapping of the feature transforms
        if softquant:
            self.post_cf1 = soft_quant(self.post_cf1)
            self.post_cf2 = soft_quant(self.post_cf2)
            self.post_af1 = soft_quant(self.post_af1)
            self.post_af2 = soft_quant(self.post_af2)
            self.post_af3 = soft_quant(self.post_af3)

        # optional sparsification of the feature transforms (densities 4..8,
        # block shape 8x4), driven by a sparsifier over the given schedule
        if sparsify:
            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))

            self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):
    """Estimate floating point operations per second for this model.

    Args:
        rate: sample rate in Hz of the processed signal.
        verbose: if True, print a per-component MFLOPS breakdown.

    Returns:
        Estimated total FLOPS (feature net + comb filters + adaptive
        convs + feature transforms + time-domain shapers).
    """
    # conditioning features are produced once per frame
    frame_rate = rate / self.FRAME_SIZE

    # feature net
    feature_net_flops = self.feature_net.flop_count(frame_rate)
    comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
    af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
    shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
    feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                     + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

    if verbose:
        print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
        print(f"comb filters: {comb_flops / 1e6} MFLOPS")
        print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
        # fix: shaper FLOPS were included in the total but never printed,
        # so the verbose breakdown did not sum to the returned value
        print(f"shapers: {shape_flops / 1e6} MFLOPS")
        print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

    return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
def feature_transform(self, f, layer):
    """Apply a causal conv feature transform; channels-last in, channels-last out."""
    # switch to (batch, channels, time) for the conv layer
    channels_first = f.permute(0, 2, 1)
    # left-pad one step so the kernel-2 conv stays causal
    padded = F.pad(channels_first, [1, 0])
    transformed = torch.tanh(layer(padded))
    # restore (batch, time, channels)
    return transformed.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):
    """Enhance x through comb filters, adaptive convs and time-domain
    shapers, all conditioned on features, pitch and bitrate embeddings."""

    def shape_and_mix(signal, shaper, cond):
        # keep channel 0 untouched, run channel 1 through the shaper
        passthrough = signal[:, 0:1, :]
        shaped = shaper(signal[:, 1:2, :], cond)
        return torch.cat((passthrough, shaped), dim=1)

    periods = periods.squeeze(-1)
    full_features = torch.cat(
        (features, self.pitch_embedding(periods), self.numbits_embedding(numbits).flatten(2)),
        dim=-1)
    cf = self.feature_net(full_features)

    # comb filtering stage (pitch-conditioned)
    y = self.cf1(x, cf, periods, debug=debug)
    cf = self.feature_transform(cf, self.post_cf1)
    y = self.cf2(y, cf, periods, debug=debug)
    cf = self.feature_transform(cf, self.post_cf2)

    # adaptive conv + shaping stages; cf is re-transformed between stages
    y = self.af1(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af1)
    y = shape_and_mix(y, self.tdshape1, cf)

    y = self.af2(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af2)
    y = shape_and_mix(y, self.tdshape2, cf)

    y = self.af3(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af3)
    y = shape_and_mix(y, self.tdshape3, cf)

    y = self.af4(y, cf, debug=debug)

    return y

View File

@@ -0,0 +1,68 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import torch
from torch import nn
class ScaleEmbedding(nn.Module):
def __init__(self,
dim,
min_val,
max_val,
logscale=False):
super().__init__()
if min_val >= max_val:
raise ValueError('min_val must be smaller than max_val')
if min_val <= 0 and logscale:
raise ValueError('min_val must be positive when logscale is true')
self.dim = dim
self.logscale = logscale
self.min_val = min_val
self.max_val = max_val
if logscale:
self.min_val = m.log(self.min_val)
self.max_val = m.log(self.max_val)
self.offset = (self.min_val + self.max_val) / 2
self.scale_factors = nn.Parameter(
torch.arange(1, dim+1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
)
def forward(self, x):
if self.logscale: x = torch.log(x)
x = torch.clip(x, self.min_val, self.max_val) - self.offset
return torch.sin(x.unsqueeze(-1) * self.scale_factors - 0.5)

View File

@@ -0,0 +1,179 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.deemph import Deemph
from utils.misc import freeze_model
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
class ShapeUp48(NNSBase):
    """Bandwidth extension model: upsamples 16 kHz codec output to target_fs
    (48 kHz) in two stages (2x, then 3/2 interpolation), applying adaptive
    filters, time-domain shapers and shaped noise injection, all conditioned
    on codec features, pitch and bitrate. Optionally runs a frozen prenet
    enhancer first."""

    # samples per frame at the 16 kHz input rate
    FRAME_SIZE16k = 80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead
        # frame sizes at the target (48 kHz) and intermediate (32 kHz) rates
        self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32 = self.FRAME_SIZE16k * 2
        self.noise_amplitude = noise_amplitude
        self.prenet = prenet

        # freeze prenet if given
        if prenet is not None:
            freeze_model(self.prenet)
            try:
                self.deemph = Deemph(prenet.preemph)
            # fix: was a bare `except:`; only a missing `preemph` attribute
            # is the expected failure here
            except AttributeError:
                print("[warning] prenet model is expected to have preemph attribute")
                self.deemph = Deemph(0)

        # upsampler
        self.upsampler = SilkUpsampler()

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (sinusoidal, log-scaled over numbits_range)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms (32 kHz and 48 kHz stages)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPS; `rate` is the 16 kHz input sample rate.

        Args:
            rate: input sample rate in Hz (intermediate/output stages are
                counted at 2x and 3x this rate).
            verbose: if True, print a per-component MFLOPS breakdown.

        Returns:
            Estimated total FLOPS.
        """
        frame_rate = rate / self.FRAME_SIZE16k

        # feature net runs at frame rate
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        # fix: af_noise and af1 operate on the 2x-upsampled (32 kHz) signal,
        # so they are counted at 2*rate (af1 was previously counted at the
        # base rate and af_noise not at all); af3 runs at 48 kHz
        af_flops = (self.af_noise.flop_count(2 * rate) + self.af1.flop_count(2 * rate)
                    + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate))
        # shapers were previously omitted from the estimate
        shape_flops = self.tdshape1.flop_count(2 * rate) + self.tdshape2.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"shapers: {shape_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops + shape_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Upsample and shape x to 48 kHz.

        Args:
            x: input signal at 16 kHz.
            features: per-frame codec features.
            periods: per-frame pitch periods (indices into the embedding).
            numbits: per-frame bitrate values for the scale embedding.
            debug: forwarded to the adaptive conv layers.

        Returns:
            The enhanced signal at 48 kHz (3x the input length).
        """
        # optional frozen enhancement prenet; its pre-emphasis is undone
        # before upsampling
        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # stage 1: 2x upsampling to 32 kHz with shaped noise injection
        y32 = self.upsampler.hq_2x_up(x)
        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)
        y32 = self.af1(y32, cf, debug=debug)

        # channel 0 passes through, channel 1 is shaped; noise joins the mix
        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)
        y32 = self.af2(y32, cf, debug=debug)

        # stage 2: 3/2 interpolation up to 48 kHz, then final shaping/mixing
        y48 = self.upsampler.interpolate_3_2(y32)
        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)

        y48 = torch.cat((y48_1, y48_2), dim=1)
        y48 = self.af3(y48, cf, debug=debug)

        return y48

View File

@@ -0,0 +1,86 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class SilkFeatureNet(nn.Module):
    """Conditioning network: two tanh conv layers followed by a GRU.

    Padding is symmetric when `lookahead` is enabled and purely causal
    otherwise, so the output always has one vector per input frame.
    """

    def __init__(self,
                 feature_dim=47,
                 num_channels=256,
                 lookahead=False):
        super(SilkFeatureNet, self).__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.lookahead = lookahead

        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)
        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

    def flop_count(self, rate=200):
        """Estimate FLOPS when running at `rate` frames per second."""
        conv_flops = sum(_conv1d_flop_count(layer, rate) for layer in (self.conv1, self.conv2))
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """features shape: (batch_size, num_frames, feature_dim) ->
        (batch_size, num_frames, num_channels)."""
        batch_size = features.size(0)

        if state is None:
            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)
        # symmetric padding uses lookahead, left-only padding stays causal;
        # conv2 needs twice the pad because of its dilation
        pad1, pad2 = ([1, 1], [2, 2]) if self.lookahead else ([2, 0], [4, 0])
        h = torch.tanh(self.conv1(F.pad(h, pad1)))
        h = torch.tanh(self.conv2(F.pad(h, pad2)))

        h = h.permute(0, 2, 1)
        out, _ = self.gru(h, state)
        return out

View File

@@ -0,0 +1,127 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import sys
sys.path.append('../dnntools')
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from utils.complexity import _conv1d_flop_count
from dnntools.quantization.softquant import soft_quant
from dnntools.sparsification import mark_for_sparsification
class SilkFeatureNetPL(nn.Module):
""" feature net with partial lookahead """
def __init__(self,
feature_dim=47,
num_channels=256,
hidden_feature_dim=64,
softquant=False,
sparsify=True,
sparsification_density=0.5,
apply_weight_norm=False):
super(SilkFeatureNetPL, self).__init__()
if isinstance(sparsification_density, numbers.Number):
sparsification_density = 4 * [sparsification_density]
self.feature_dim = feature_dim
self.num_channels = num_channels
self.hidden_feature_dim = hidden_feature_dim
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
if softquant:
self.conv2 = soft_quant(self.conv2)
self.tconv = soft_quant(self.tconv)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
if sparsify:
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{
'W_ir' : (sparsification_density[2], [8, 4], False),
'W_iz' : (sparsification_density[2], [8, 4], False),
'W_in' : (sparsification_density[2], [8, 4], False),
'W_hr' : (sparsification_density[3], [8, 4], True),
'W_hz' : (sparsification_density[3], [8, 4], True),
'W_hn' : (sparsification_density[3], [8, 4], True),
}
)
def flop_count(self, rate=200):
count = 0
for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
return count
def forward(self, features, state=None):
""" features shape: (batch_size, num_frames, feature_dim) """
batch_size = features.size(0)
num_frames = features.size(1)
if state is None:
state = torch.zeros((1, batch_size, self.num_channels), device=features.device)
features = features.permute(0, 2, 1)
# dimensionality reduction
c = torch.tanh(self.conv1(features))
# frame accumulation
c = c.permute(0, 2, 1)
c = c.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
# upsampling
c = torch.tanh(self.tconv(c))
c = c.permute(0, 2, 1)
c, _ = self.gru(c, state)
return c