add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,66 @@
# Opus Speech Coding Enhancement
This folder hosts models for enhancing Opus SILK.
## Environment setup
The code is tested with python 3.11. Conda setup is done via
`conda create -n osce python=3.11`
`conda activate osce`
`python -m pip install -r requirements.txt`
## Generating training data
First step is to convert all training items to 16 kHz and 16 bit pcm and then concatenate them. A convenient way to do this is to create a file list and then run
`python scripts/concatenator.py filelist 16000 dataset/clean.s16 --db_min -40 --db_max 0`
which additionally applies random level scaling to the items. Data is taken from the datasets listed in dnn/datasets.txt and the exact list of items used for training and validation is
located in dnn/torch/osce/resources.
Second step is to run a patched version of opus_demo in the dataset folder, which will produce the coded output and add feature files. To build the patched opus_demo binary, check out the exp-neural-silk-enhancement branch and build opus_demo the usual way. Then run
`cd dataset && <path_to_patched_opus_demo>/opus_demo voip 16000 1 9000 -silk_random_switching 249 clean.s16 coded.s16 `
The argument to -silk_random_switching specifies the number of frames after which parameters are switched randomly.
## Regression loss based training
Create a default setup for LACE or NoLACE via
`python make_default_setup.py model.yml --model lace/nolace --path2dataset <path2dataset>`
Then run
`python train_model.py model.yml <output folder> --no-redirect`
for running the training script in foreground or
`nohup python train_model.py model.yml <output folder> &`
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
## Adversarial training (NoLACE only)
Create a default setup for NoLACE via
`python make_default_setup.py nolace_adv.yml --model nolace --adversarial --path2dataset <path2dataset>`
Then run
`python adv_train_model.py nolace_adv.yml <output folder> --no-redirect`
for running the training script in foreground or
`nohup python adv_train_model.py nolace_adv.yml <output folder> &`
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
## Inference
Generating inference data is analogous to generating training data. Given an item 'item1.wav' run
`mkdir item1.se && sox item1.wav -r 16000 -e signed-integer -b 16 item1.raw && cd item1.se && <path_to_patched_opus_demo>/opus_demo voip 16000 1 <bitrate> ../item1.raw noisy.s16`
The folder item1.se then serves as input for the test_model.py script, or for the --testdata argument of train_model.py and adv_train_model.py, respectively.
autogen.sh downloads pre-trained model weights to the subfolder dnn/models of the main repo.

View File

@@ -0,0 +1,462 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import SilkEnhancementSet
from models import model_dict
from utils.silk_features import load_inference_data
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
# ---------------------------------------------------------------------------
# command-line interface and training setup
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

# keep CPU usage bounded on shared machines
torch.set_num_threads(4)

# read the experiment description from the YAML setup file
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# file-name constants used for outputs below
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'
# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires a status code; the original bare call
        # would raise TypeError instead of exiting
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup (for reproducibility of the run)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        # renamed from `hash` to avoid shadowing the builtin
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        # bug fix: was a bare `except:`, which also swallows SystemExit /
        # KeyboardInterrupt; not running inside a git checkout is best-effort
        has_git = False

# dump setup so the run can be reproduced later
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)
# optional inference test data: features/signal for periodic evaluation
ref = None
if args.testdata is not None:
    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except OSError:
        # bug fix: was a bare `except:`; the clean reference is optional and
        # only needed for PESQ scoring, so a missing file is tolerated
        pass
else:
    inference_test = False
# training parameters
training_cfg = setup['training']
batch_size = training_cfg['batch_size']
epochs = training_cfg['epochs']
lr = training_cfg['lr']
lr_decay_factor = training_cfg['lr_decay_factor']
# generator learns with a reduced learning rate relative to the discriminator
lr_gen = lr * training_cfg['gen_lr_reduction']
lambda_feat = training_cfg['lambda_feat']
lambda_reg = training_cfg['lambda_reg']
adv_target = training_cfg.get('adv_target', 'target')

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
run_validation = 'validation_dataset' in setup
if run_validation:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device (idiom fix: `is None` instead of type comparison)
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# disc optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler: hyperbolic decay with the step count
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))
# optionally resume from an initial checkpoint (each sub-state is restored
# only if present in the checkpoint dict)
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: checkpoints are saved with the key 'scheduler_state_dict' (see
    # the per-epoch save code below); the original check for
    # 'scheduler_state_disc' never matched, so scheduler state was silently
    # never restored
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
# loss weights from the setup file; w_sum normalizes the combined criterion
loss_cfg = setup['training']['loss']
w_l1 = loss_cfg['w_l1']
w_lm = loss_cfg['w_lm']
w_slm = loss_cfg['w_slm']
w_sc = loss_cfg['w_sc']
w_logmel = loss_cfg['w_logmel']
w_wsc = loss_cfg['w_wsc']
w_xcorr = loss_cfg['w_xcorr']
w_sxcorr = loss_cfg['w_sxcorr']
w_l2 = loss_cfg['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

# multi-resolution spectral losses
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc,
                      smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss; 0 for perfectly aligned signals."""
    dims = list(range(1, y_true.dim()))
    num = torch.sum(y_true * y_pred, dim=dims)
    den = torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
    return torch.mean(1 - num / den)


def td_l2_norm(y_true, y_pred):
    """Time-domain squared error, normalized by the prediction's RMS."""
    dims = list(range(1, y_true.dim()))
    sq_err = torch.mean((y_true - y_pred) ** 2, dim=dims)
    scale = torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6
    return (sq_err / scale).mean()


def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 error; pow > 0 normalizes by mean |y_pred|."""
    dims = list(range(1, y_true.dim()))
    abs_err = torch.mean(torch.abs(y_true - y_pred), dim=dims)
    scale = (torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow
    return torch.mean(abs_err / scale)


def criterion(x, y):
    """Weighted regression loss combining time-domain and spectral terms.

    Uses the module-level weights w_* and the stftloss/logmelloss instances.
    """
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# skeleton of the checkpoint dict; populated after every epoch
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}

# optionally redirect stdout to a log file inside the output folder
if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

# print a short model summary
print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

# baseline PESQ of the coded signal, if a clean reference is available
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9
log_interval = 10

# exponential moving averages of discriminator score mean/std on real (r)
# and fake (f) data, used for the "winning chance" diagnostic
m_r = 0
m_f = 0
s_r = 1
s_f = 1
def optimizer_to(optim, device):
    """Move all tensors held in an optimizer's state to `device`, in place.

    Needed after loading a checkpoint whose optimizer state may live on a
    different device than the model parameters.
    """
    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            state.data = state.data.to(device)
            if state._grad is not None:
                state._grad.data = state._grad.data.to(device)
        elif isinstance(state, dict):
            for entry in state.values():
                if isinstance(entry, torch.Tensor):
                    entry.data = entry.data.to(device)
                    if entry._grad is not None:
                        entry._grad.data = entry._grad.data.to(device)
# make sure optimizer state tensors live on the compute device (a loaded
# checkpoint may have been mapped to a different device)
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)
# keep per-parameter gradients around so gradient norms can be logged
retain_grads(model)
retain_grads(disc)
# === adversarial training loop =============================================
# Each batch performs one discriminator update followed by one generator
# update (LSGAN-style squared-error objectives). Running loss sums feed the
# progress bar; a checkpoint is written after every epoch.
for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    # per-epoch accumulators for logging and checkpoint metadata
    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0
    running_disc_grad_norm = 0
    running_model_grad_norm = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            # adv_target selects which signal the discriminator sees as "real"
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # discriminator update: fake scores on detached generator output,
            # real scores on the adversarial target
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for score in scores_gen:
                # score[-1] is the final (decision) output of each sub-discriminator
                disc_loss += (((score[-1]) ** 2)).mean()
                # EMA of score statistics on fake data
                m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()

            for score in scores_real:
                disc_loss += (((1 - score[-1]) ** 2)).mean()
                # EMA of score statistics on real data
                m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            # diagnostic: probability that a fake score exceeds a real score,
            # assuming (approximately) Gaussian score distributions
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()

            running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()

            optimizer_disc.step()

            # generator update: re-score the (non-detached) output
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            gen_loss = 0
            for score in scores_gen:
                gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs

            # feature-matching loss over all intermediate discriminator layers
            # (the last entry of each score list is the decision output)
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            # sparsification (only if the model implements it)
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
            running_adv_loss += gen_loss.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
                                   disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")

    # save checkpoint (model, discriminator, optimizers, scheduler, RNG
    # states and the epoch's mean losses)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss/(i + 1)
    checkpoint['disc_loss'] = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)

    # per-epoch inference test: write the enhanced test item and, if a clean
    # reference is available, report wideband PESQ
    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)

        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

print()
print('Done')

View File

@@ -0,0 +1,451 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import math as m
import random
import yaml
from tqdm import tqdm
try:
import git
has_git = True
except:
has_git = False
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from scipy.io import wavfile
import numpy as np
import pesq
from data import LPCNetVocodingDataset
from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
# ---------------------------------------------------------------------------
# command-line interface and training setup
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

# keep CPU usage bounded on shared machines
torch.set_num_threads(4)

# read the experiment description from the YAML setup file
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

# file-name constants used for outputs below
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'
# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires a status code; the original bare call
        # would raise TypeError instead of exiting
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup (for reproducibility of the run)
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        # renamed from `hash` to avoid shadowing the builtin
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        # bug fix: was a bare `except:`, which also swallows SystemExit /
        # KeyboardInterrupt; not running inside a git checkout is best-effort
        has_git = False

# dump setup so the run can be reproduced later
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)
ref = None

# prepare inference test if wanted
# idiom fix: `is not None` instead of `type(...) != type(None)`
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True
# training parameters
training_cfg = setup['training']
batch_size = training_cfg['batch_size']
epochs = training_cfg['epochs']
lr = training_cfg['lr']
lr_decay_factor = training_cfg['lr_decay_factor']
# generator learns with a reduced learning rate relative to the discriminator
lr_gen = lr * training_cfg['gen_lr_reduction']
lambda_feat = training_cfg['lambda_feat']
lambda_reg = training_cfg['lambda_reg']
adv_target = training_cfg.get('adv_target', 'target')

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
run_validation = 'validation_dataset' in setup
if run_validation:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# create discriminator
disc_name = setup['discriminator']['name']
disc = model_dict[disc_name](
    *setup['discriminator']['args'], **setup['discriminator']['kwargs']
)

# set compute device (idiom fix: `is None` instead of type comparison)
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr_gen)

# disc optimizer
parameters = [p for p in disc.parameters() if p.requires_grad]
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])

# learning rate scheduler: hyperbolic decay with the step count
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))
# optionally resume from an initial checkpoint (each sub-state is restored
# only if present in the checkpoint dict)
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: checkpoints are saved with the key 'scheduler_state_dict' (see
    # the per-epoch save code below); the original check for
    # 'scheduler_state_disc' never matched, so scheduler state was silently
    # never restored
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
# loss weights from the setup file; w_sum normalizes the combined criterion
loss_cfg = setup['training']['loss']
w_l1 = loss_cfg['w_l1']
w_lm = loss_cfg['w_lm']
w_slm = loss_cfg['w_slm']
w_sc = loss_cfg['w_sc']
w_logmel = loss_cfg['w_logmel']
w_wsc = loss_cfg['w_wsc']
w_xcorr = loss_cfg['w_xcorr']
w_sxcorr = loss_cfg['w_sxcorr']
w_l2 = loss_cfg['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

# multi-resolution spectral losses
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc,
                      smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss; 0 for perfectly aligned signals."""
    dims = list(range(1, y_true.dim()))
    num = torch.sum(y_true * y_pred, dim=dims)
    den = torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)
    return torch.mean(1 - num / den)


def td_l2_norm(y_true, y_pred):
    """Time-domain squared error, normalized by the prediction's RMS."""
    dims = list(range(1, y_true.dim()))
    sq_err = torch.mean((y_true - y_pred) ** 2, dim=dims)
    scale = torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6
    return (sq_err / scale).mean()


def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 error; pow > 0 normalizes by mean |y_pred|."""
    dims = list(range(1, y_true.dim()))
    abs_err = torch.mean(torch.abs(y_true - y_pred), dim=dims)
    scale = (torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow
    return torch.mean(abs_err / scale)


def criterion(x, y):
    """Weighted regression loss combining time-domain and spectral terms.

    Uses the module-level weights w_* and the stftloss/logmelloss instances.
    """
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# skeleton of the checkpoint dict; populated after every epoch
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}

# optionally redirect stdout to a log file inside the output folder
if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

# print a short model summary
print("summary:")
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")

# bug fix: an initial-PESQ block copied from the SILK-enhancement training
# script was removed here. This script has no --testdata argument and `ref`
# is always None, so the branch could never run — and would have raised
# AttributeError (args.testdata) if it had.

best_loss = 1e9
log_interval = 10

# exponential moving averages of discriminator score mean/std on real (r)
# and fake (f) data, used for the "winning chance" diagnostic
m_r = 0
m_f = 0
s_r = 1
s_f = 1
def optimizer_to(optim, device):
    """Move all tensors held in an optimizer's state to `device`, in place.

    Needed after loading a checkpoint whose optimizer state may live on a
    different device than the model parameters.
    """
    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            state.data = state.data.to(device)
            if state._grad is not None:
                state._grad.data = state._grad.data.to(device)
        elif isinstance(state, dict):
            for entry in state.values():
                if isinstance(entry, torch.Tensor):
                    entry.data = entry.data.to(device)
                    if entry._grad is not None:
                        entry._grad.data = entry._grad.data.to(device)

# make sure optimizer state tensors live on the compute device
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)
# === adversarial training loop =============================================
# Each batch performs one discriminator update followed by one generator
# update (LSGAN-style squared-error objectives). Running loss sums feed the
# progress bar; a checkpoint is written after every epoch.
for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    # per-epoch accumulators for logging and checkpoint metadata
    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            # adv_target selects which signal the discriminator sees as "real"
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # discriminator update: fake scores on detached generator output,
            # real scores on the adversarial target
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for scale in scores_gen:
                # scale[-1] is the final (decision) output of each sub-discriminator
                disc_loss += ((scale[-1]) ** 2).mean()
                # EMA of score statistics on fake data
                m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

            for scale in scores_real:
                disc_loss += ((1 - scale[-1]) ** 2).mean()
                # EMA of score statistics on real data
                m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            # diagnostic: probability that a fake score exceeds a real score,
            # assuming (approximately) Gaussian score distributions
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

            # generator update: re-score the (non-detached) output
            scores_gen = disc(output)

            # calculate loss
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            loss_gen = 0
            for scale in scores_gen:
                loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs

            # feature-matching loss over all intermediate discriminator layers
            # (the last entry of each score list is the decision output)
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            running_adv_loss += loss_gen.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")

    # save checkpoint (model, discriminator, optimizers, scheduler, RNG
    # states and the epoch's mean losses)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss/(i + 1)
    checkpoint['disc_loss'] = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)

    # per-epoch inference test: vocode the test features and write the result
    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)

        # NOTE(review): `ref` is always None in this script (no --testdata
        # argument), so this PESQ branch never executes
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

print()
print('Done')

View File

@@ -0,0 +1,165 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import torch
import numpy as np
from models import model_dict
from utils import endoscopy
# Command-line interface: generate C testvectors from trained LACE/NoLACE checkpoints.
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and nolace_checkpoint.pth"')
parser.add_argument('output_folder', type=str, help='output folder for testvectors')
parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
    """Run `adaconv` on random input and dump features/input/output as raw .f32 files.

    Output layout is (frame, channel, sample) so the C implementation can be
    compared frame by frame against the dumped vectors.
    """
    num_in = adaconv.in_channels
    num_out = adaconv.out_channels
    fsize = adaconv.frame_size

    cond = torch.randn((1, num_frames, adaconv.feature_dim))
    signal = torch.randn((1, num_in, num_frames * fsize))

    result = adaconv(signal, cond, debug=debug)

    # reorder from (channel, frame*sample) to (frame, channel, sample) before dumping
    cond_np = cond[0].detach().numpy()
    signal_np = signal[0].reshape(num_in, num_frames, fsize).permute(1, 0, 2).detach().numpy()
    result_np = result[0].reshape(num_out, num_frames, fsize).permute(1, 0, 2).detach().numpy()

    cond_np.tofile(prefix + '_features.f32')
    signal_np.tofile(prefix + '_x_in.f32')
    result_np.tofile(prefix + '_x_out.f32')
def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
    """Run `adacomb` on random input and pitch lags and dump all testvectors.

    Pitch lags are drawn from [kernel_size, 250) so every lag has enough
    signal history inside the comb filter.
    """
    fsize = adacomb.frame_size

    cond = torch.randn((1, num_frames, adacomb.feature_dim))
    signal = torch.randn((1, 1, num_frames * fsize))
    pitch = torch.randint(adacomb.kernel_size, 250, (1, num_frames))

    result = adacomb(signal, cond, pitch, debug=debug)

    cond[0].detach().numpy().tofile(prefix + '_features.f32')
    signal[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_in.f32')
    pitch[0].detach().numpy().astype(np.int32).tofile(prefix + '_p_in.s32')
    result[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_out.f32')
def create_adashape_testvector(prefix, adashape, num_frames):
    """Run `adashape` on random input and dump features/input/output testvectors."""
    cond = torch.randn((1, num_frames, adashape.feature_dim))
    signal = torch.randn((1, 1, num_frames * adashape.frame_size))

    result = adashape(signal, cond)

    cond[0].detach().numpy().tofile(prefix + '_features.f32')
    signal.flatten().detach().numpy().tofile(prefix + '_x_in.f32')
    result.flatten().detach().numpy().tofile(prefix + '_x_out.f32')
def create_feature_net_testvector(prefix, model, num_frames):
    """Run the model's feature net on random input and dump every intermediate.

    Features run at subframe rate (4 subframes per frame) while the rate
    information (numbits) runs at frame rate and is upsampled via
    repeat_interleave, mirroring the inference path.
    """
    num_subframes = 4 * num_frames

    raw_features = torch.randn((1, num_subframes, model.num_features))
    pitch_lags = torch.randint(32, 300, (1, num_subframes))
    lo = model.numbits_range[0]
    hi = model.numbits_range[1]
    rate_info = lo + torch.rand((1, num_frames, 2)) * (hi - lo)

    # embed pitch and rate information and append them to the raw features
    pitch_embedded = model.pitch_embedding(pitch_lags)
    rate_embedded = torch.repeat_interleave(model.numbits_embedding(rate_info).flatten(2), 4, dim=1)
    full_features = torch.cat((raw_features, pitch_embedded, rate_embedded), dim=-1)

    out_features = model.feature_net(full_features)

    raw_features.float().numpy().tofile(prefix + "_in_features.f32")
    pitch_lags.numpy().astype(np.int32).tofile(prefix + "_periods.s32")
    rate_info.float().numpy().tofile(prefix + "_numbits.f32")
    full_features.detach().numpy().tofile(prefix + "_full_features.f32")
    out_features.detach().numpy().tofile(prefix + "_out_features.f32")
if __name__ == "__main__":
    args = parser.parse_args()

    os.makedirs(args.output_folder, exist_ok=True)

    # load LACE and NoLACE checkpoints and rebuild both models from the stored setup
    lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
    nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')

    lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
    nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])

    lace.load_state_dict(lace_checkpoint['state_dict'])
    nolace.load_state_dict(nolace_checkpoint['state_dict'])

    # optionally record intermediate-layer debug output next to the testvectors
    if args.debug:
        endoscopy.init(args.output_folder)

    # lace af1, 1 input channel, 1 output channel
    create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, 5, debug=args.debug)

    # nolace af1, 1 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, 5, debug=args.debug)

    # nolace af4, 2 input channel, 1 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, 5, debug=args.debug)

    # nolace af2, 2 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, 5, debug=args.debug)

    # lace cf1
    create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, 5, debug=args.debug)

    # nolace tdshape1
    create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, 5)

    # lace feature net
    create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, 5)

    if args.debug:
        endoscopy.close()

View File

@@ -0,0 +1,2 @@
from .silk_enhancement_set import SilkEnhancementSet
from .lpcnet_vocoding_dataset import LPCNetVocodingDataset

View File

@@ -0,0 +1,225 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Dataset for LPCNet training """
import os
import yaml
import torch
import numpy as np
from torch.utils.data import Dataset
scale = 255.0/32768.0
scale_1 = 32768.0/255.0
def ulaw2lin(u):
u = u - 128
s = np.sign(u)
u = np.abs(u)
return s*scale_1*(np.exp(u/128.*np.log(256))-1)
def lin2ulaw(x):
s = np.sign(x)
x = np.abs(x)
u = (s*(128*np.log(1+scale*x)/np.log(256)))
u = np.clip(128 + np.round(u), 0, 255)
return u
def run_lpc(signal, lpcs, frame_length=160):
    """Apply per-frame short-term prediction filters to `signal`.

    `signal` must carry `lpc_order` extra history samples at the front;
    returns (prediction, error) where error = signal[lpc_order:] - prediction.
    """
    num_frames, lpc_order = lpcs.shape

    per_frame = []
    for frame in range(num_frames):
        # segment includes lpc_order - 1 history samples so 'valid' yields frame_length outputs
        segment = signal[frame * frame_length : (frame + 1) * frame_length + lpc_order - 1]
        per_frame.append(-np.convolve(segment, lpcs[frame], mode='valid'))

    prediction = np.concatenate(per_frame)
    error = signal[lpc_order :] - prediction

    return prediction, error
class LPCNetVocodingDataset(Dataset):
    """Frame-aligned dataset pairing LPCNet acoustic features with signal targets.

    Reads a memmapped feature file and a memmapped signal file described by the
    dataset's info.yml and serves fixed-length samples of `frames_per_sample`
    feature frames plus optional feature history/lookahead. Two on-disk layout
    versions (v1 and v2) are supported; the matching __getitem__ implementation
    is selected in __init__.
    """
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):
        # NOTE(review): `features` uses a mutable default; safe only as long as
        # callers never mutate the list they receive.
        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the matching __getitem__ implementation
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        # first served frame: 2 warm-up frames plus the requested feature history
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        # bandwidth-expansion factor applied to LPCs in getitem_v2
        self.lpc_gamma = lpc_gamma

        # load feature file (memmapped; one row per feature frame)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset (margins reserved for offset/lookahead)
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # load signals (memmapped; frame_length samples per feature frame)
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))

        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        # dispatch to the version-specific implementation chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Assemble one v2 sample: float features, pitch periods and float target."""
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        # NOTE(review): maps the normalized period feature to integer lags; the
        # scaling (0.1 + 50*p + 100) presumably mirrors the feature-generation
        # code -- confirm against it.
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:

            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting: bandwidth expansion by gamma^(k+1) per coefficient
            lpc_order = lpc_stop - lpc_start
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            # NOTE(review): clean_signal is computed but not used below -- verify
            # whether it was meant to feed a clean-prediction path.
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals: drop the one extra leading frame again
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features (periods are returned separately as integers)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # normalize 16 bit PCM target to [-1, 1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        """Assemble one v1 sample: features, pitch periods, integer signals and target."""
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # NOTE(review): the loop variable shadows the `index` argument; harmless
        # here because signal_start/stop are computed above, but worth renaming.
        for signal_name, index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # NOTE(review): self.input_signals is never defined anywhere in this class,
        # so this line raises AttributeError for v1 datasets; the list of signal
        # names likely needs to come from info.yml or an __init__ argument -- TODO.
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length

View File

@@ -0,0 +1,132 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
from torch.utils.data import Dataset
import numpy as np
from utils.silk_features import silk_feature_factory
from utils.pitch import hangover, calculate_acorr_window
class SilkEnhancementSet(Dataset):
    """Dataset serving SILK-coded speech plus LPCNet features for enhancement training.

    Loads raw SILK decoder features (LPCs, LTP coefficients, pitch lags, gains,
    rate information, offsets) together with the coded 16 kHz signal produced by
    the patched opus_demo run, and serves fixed-length excerpts of
    `frames_per_sample` 5 ms frames.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=9,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_offset=False,
                 add_double_lag_acorr=False
                 ):

        # frames are grouped in SILK packets of 4; samples must align with packets
        assert frames_per_sample % 4 == 0

        # 5 ms frames at 16 kHz
        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr
        # bugfix: self.skip was read in __getitem__ but never set, which raised
        # AttributeError; this dataset variant has no delay-compensation offset,
        # so it is fixed at 0.
        self.skip = 0

        # per-frame SILK decoder features dumped by the patched opus_demo
        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        # per-packet rate information (one value per 20 ms packet)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
        self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
        # bugfix: np.from_file does not exist; np.fromfile is the correct API
        self.lpcnet_features = np.fromfile(os.path.join(path, 'features_lpcnet.f32'), dtype=np.float32).reshape(-1, 36)

        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_offset,
                                                    add_double_lag_acorr)

        # signal history needed by the feature computation (doubled when
        # double-lag autocorrelation features are enabled)
        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some frames to have enough signal history
        self.skip_frames = 4 * ((self.history_len + 319) // 320 + 2)

        # bugfix: this class never loads a clean signal (self.clean_signal was
        # undefined); the frame count is derived from the coded signal, which is
        # what __getitem__ actually slices.
        num_frames = self.coded_signal.shape[0] // 80 - self.skip_frames
        self.len = num_frames // frames_per_sample

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return one excerpt: SILK features, periods, rate info and LPCNet features."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        # convert frame indices to sample indices
        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # normalize 16 bit PCM to [-1, 1)
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop],
            self.offsets[frame_start : frame_stop]
        )

        # LPCNet features run at 10 ms frames -> half the 5 ms frame rate;
        # only the first 20 columns (cepstrum + pitch) are used
        lpcnet_features = self.lpcnet_features[frame_start // 2 : frame_stop // 2, :20]

        # upsample per-packet rate info to frame rate (4 frames per packet)
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'silk_features' : features,
            'periods' : periods.astype(np.int64),
            'numbits' : numbits.astype(np.float32),
            'lpcnet_features' : lpcnet_features
        }

View File

@@ -0,0 +1,140 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
from torch.utils.data import Dataset
import numpy as np
from utils.silk_features import silk_feature_factory
from utils.pitch import hangover, calculate_acorr_window
class SilkEnhancementSet(Dataset):
    """Dataset pairing SILK-coded speech with clean targets for enhancement training.

    Loads raw SILK decoder features (LPCs, LTP coefficients, pitch lags, gains,
    rate information) together with the coded and clean 16 kHz signals produced
    by the patched opus_demo run, and serves fixed-length excerpts of
    `frames_per_sample` 5 ms frames.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=256,
                 preemph=0.85,
                 skip=91,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_double_lag_acorr=False,
                 ):

        # frames are grouped in SILK packets of 4; samples must align with packets
        assert frames_per_sample % 4 == 0

        # 5 ms frames at 16 kHz
        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        # period value used for unvoiced frames
        self.no_pitch_value = no_pitch_value
        # first-order pre-emphasis factor applied to targets and coded signal
        self.preemph = preemph
        # sample offset applied to all signals -- presumably compensates the
        # codec pipeline delay; TODO confirm against the opus_demo setup
        self.skip = skip
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr

        # per-frame SILK decoder features dumped by the patched opus_demo
        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        # per-packet rate information (one value per 20 ms packet)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)

        # 16 bit PCM signals: high-passed clean target, original clean target, coded signal
        self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
        self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_double_lag_acorr)

        # signal history needed by the feature computation (doubled when
        # double-lag autocorrelation features are enabled)
        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some frames to have enough signal history
        self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)

        num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames
        self.len = num_frames // frames_per_sample

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return one excerpt: features, periods, targets, coded signal and rate info."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        # convert frame indices to sample indices, compensating the delay offset
        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # normalize 16 bit PCM to [-1, 1)
        clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15
        clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop]
        )

        # first-order pre-emphasis y[n] = x[n] - preemph * x[n-1], applied in
        # place after feature extraction (features see the un-emphasized signal)
        if self.preemph > 0:
            clean_signal[1:] -= self.preemph * clean_signal[: -1]
            clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1]
            coded_signal[1:] -= self.preemph * coded_signal[: -1]

        # upsample per-packet rate info to frame rate (4 frames per packet)
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'features' : features,
            'periods' : periods.astype(np.int64),
            'target_orig' : clean_signal.astype(np.float32),
            'target' : clean_signal_hp.astype(np.float32),
            'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
            'numbits' : numbits.astype(np.float32)
        }

View File

@@ -0,0 +1,103 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run a single training epoch and return the average per-batch loss."""

    model.to(device)
    model.train()

    loss_total = 0
    loss_total_at_last_log = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as progress:
        for step, batch in enumerate(progress):

            # reset gradients from the previous step
            optimizer.zero_grad()

            # move batch tensors to the training device (in place)
            for key in batch:
                batch[key] = batch[key].to(device)

            ground_truth = batch['target']

            # forward pass
            prediction = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # loss: average over intermediate outputs if the model returns a list
            if isinstance(prediction, list):
                total = torch.zeros(1, device=device)
                for stage_out in prediction:
                    total = total + criterion(ground_truth, stage_out.squeeze(1))
                loss = total / len(prediction)
            else:
                loss = criterion(ground_truth, prediction.squeeze(1))

            # backward pass and parameter update
            loss.backward()
            optimizer.step()

            # per-step learning-rate schedule
            scheduler.step()

            # apply sparsification mask if the model provides one
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            loss_total += float(loss.cpu())

            # refresh progress-bar statistics every log_interval steps
            if step % log_interval == 0:
                progress.set_postfix(
                    running_loss=f"{loss_total/(step + 1):8.7f}",
                    current_loss=f"{(loss_total - loss_total_at_last_log)/log_interval:8.7f}"
                )
                loss_total_at_last_log = loss_total

    return loss_total / len(dataloader)
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Run a full validation pass (no gradients) and return the average per-batch loss."""

    model.to(device)
    model.eval()

    loss_total = 0
    loss_total_at_last_log = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as progress:
            for step, batch in enumerate(progress):

                # move batch tensors to the evaluation device (in place)
                for key in batch:
                    batch[key] = batch[key].to(device)

                ground_truth = batch['target']

                # forward pass
                prediction = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                loss = criterion(ground_truth, prediction.squeeze(1))

                loss_total += float(loss.cpu())

                # refresh progress-bar statistics every log_interval steps
                if step % log_interval == 0:
                    progress.set_postfix(
                        running_loss=f"{loss_total/(step + 1):8.7f}",
                        current_loss=f"{(loss_total - loss_total_at_last_log)/log_interval:8.7f}"
                    )
                    loss_total_at_last_log = loss_total

    return loss_total / len(dataloader)

View File

@@ -0,0 +1,101 @@
import torch
from tqdm import tqdm
import sys
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run a single vocoder training epoch and return the average per-batch loss."""

    model.to(device)
    model.train()

    loss_total = 0
    loss_total_at_last_log = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as progress:
        for step, batch in enumerate(progress):

            # reset gradients from the previous step
            optimizer.zero_grad()

            # move batch tensors to the training device (in place)
            for key in batch:
                batch[key] = batch[key].to(device)

            ground_truth = batch['target']

            # forward pass
            prediction = model(batch['features'], batch['periods'])

            # loss: average over intermediate outputs if the model returns a list
            if isinstance(prediction, list):
                total = torch.zeros(1, device=device)
                for stage_out in prediction:
                    total = total + criterion(ground_truth, stage_out.squeeze(1))
                loss = total / len(prediction)
            else:
                loss = criterion(ground_truth, prediction.squeeze(1))

            # backward pass and parameter update
            loss.backward()
            optimizer.step()

            # per-step learning-rate schedule
            scheduler.step()

            loss_total += float(loss.cpu())

            # refresh progress-bar statistics every log_interval steps
            if step % log_interval == 0:
                progress.set_postfix(
                    running_loss=f"{loss_total/(step + 1):8.7f}",
                    current_loss=f"{(loss_total - loss_total_at_last_log)/log_interval:8.7f}"
                )
                loss_total_at_last_log = loss_total

    return loss_total / len(dataloader)
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Run a full vocoder validation pass (no gradients) and return the average loss."""

    model.to(device)
    model.eval()

    loss_total = 0
    loss_total_at_last_log = 0

    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as progress:
            for step, batch in enumerate(progress):

                # move batch tensors to the evaluation device (in place)
                for key in batch:
                    batch[key] = batch[key].to(device)

                ground_truth = batch['target']

                # forward pass
                prediction = model(batch['features'], batch['periods'])

                loss = criterion(ground_truth, prediction.squeeze(1))

                loss_total += float(loss.cpu())

                # refresh progress-bar statistics every log_interval steps
                if step % log_interval == 0:
                    progress.set_postfix(
                        running_loss=f"{loss_total/(step + 1):8.7f}",
                        current_loss=f"{(loss_total - loss_total_at_last_log)/log_interval:8.7f}"
                    )
                    loss_total_at_last_log = loss_total

    return loss_total / len(dataloader)

View File

@@ -0,0 +1,174 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys
import hashlib
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
import torch
import wexchange.torch
from wexchange.torch import dump_torch_weights
from models import model_dict
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.misc import remove_all_weight_norm
from wexchange.torch import dump_torch_weights
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
# default sparsity flag applied to the quantized layers in the schedules below
sparse_default=False

# per-model export schedules: ordered (submodule name, dump_torch_weights kwargs)
# pairs describing how each layer is quantized and dumped to C
schedules = {
    'nolace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None)),
        ('tdshape1', dict(quantize=True, scale=None)),
        ('tdshape2', dict(quantize=True, scale=None)),
        ('tdshape3', dict(quantize=True, scale=None)),
        ('af2', dict(quantize=True, scale=None)),
        ('af3', dict(quantize=True, scale=None)),
        ('af4', dict(quantize=True, scale=None)),
        ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
    ],
    'lace' : [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None))
    ]
}
# auxiliary functions
def sha1(filename):
    """Return the hex SHA-1 digest of the file at `filename`, read in 64 KiB chunks."""
    chunk_size = 65536
    digest = hashlib.sha1()
    with open(filename, 'rb') as f:
        # iterate fixed-size reads until the sentinel empty read at EOF
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def osce_dump_generic(writer, name, module):
    """Recursively dump all exportable leaf modules of *module* via dump_torch_weights.

    Leaf modules of a known type are dumped under *name*; containers are
    descended into, appending the child name (with 'feature_net' shortened
    to 'fnet') to the running prefix.
    """
    # module types that dump_torch_weights knows how to export directly;
    # a tuple isinstance check replaces the original chained or-expression
    dumpable_types = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d,
                      torch.nn.Embedding, LimitedAdaptiveConv1d, LimitedAdaptiveComb1d,
                      TDShaper, torch.nn.GRU)
    if isinstance(module, dumpable_types):
        dump_torch_weights(writer, module, name=name, verbose=True)
    else:
        for child_name, child in module.named_children():
            osce_dump_generic(writer, (name + "_" + child_name).replace("feature_net", "fnet"), child)
def export_name(name):
    """Map a PyTorch submodule path to a C-friendly identifier.

    Dots become underscores and the 'feature_net' prefix is shortened to 'fnet'.
    """
    return name.replace('.', '_').replace('feature_net', 'fnet')
def osce_scheduled_dump(writer, prefix, model, schedule):
    """Dump the submodules listed in *schedule* with their per-layer export options.

    Each schedule entry is a (submodule path, dump_torch_weights kwargs) pair;
    the exported name is the prefix joined to the C-friendly submodule name.
    """
    separator = '' if prefix.endswith('_') else '_'
    for sub_path, export_kwargs in schedule:
        submodule = model.get_submodule(sub_path)
        dump_torch_weights(writer, submodule, prefix + separator + export_name(sub_path),
                           **export_kwargs, verbose=True)
if __name__ == "__main__":
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # dump message: records checkpoint provenance in the generated C sources
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # create model and load weights; the checkpoint's 'setup' section carries
    # the constructor name and arguments used at training time
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    # weight norm must be folded into plain weights before export
    remove_all_weight_norm(model, verbose=True)

    # CWriter emits the <model>_data.{c,h} pair
    model_name = checkpoint['setup']['model']['name']
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)

    # Add custom includes and global parameters
    cwriter.header.write(f'''
#define {model_name.upper()}_PREEMPH {model.preemph}f
#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
#define {model_name.upper()}_OVERLAP_SIZE 40
#define {model_name.upper()}_NUM_FEATURES {model.num_features}
#define {model_name.upper()}_PITCH_MAX {model.pitch_max}
#define {model_name.upper()}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
#define {model_name.upper()}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
#define {model_name.upper()}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
#define {model_name.upper()}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
#define {model_name.upper()}_COND_DIM {model.cond_dim}
#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')

    # one scale define per numbits-embedding scale factor
    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")

    # dump layers: quantized per-layer schedule when available and requested,
    # otherwise a generic float dump of all exportable submodules
    if model_name in schedules and args.quantize:
        osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
    else:
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()

View File

@@ -0,0 +1,277 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
    """Return a torch window tensor of length *win_length* by window name.

    Args:
        win_name (str): one of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window', 'kaiser_window'.
        win_length (int): window length in samples.
        *args, **kwargs: forwarded to the underlying torch window factory.

    Raises:
        ValueError: if *win_name* is not a supported window name.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window'  : torch.hamming_window,
        'hann_window'     : torch.hann_window,
        'kaiser_window'   : torch.kaiser_window
    }

    try:
        win_fn = window_dict[win_name]
    except KeyError:
        # original raised a bare ValueError(); include the offending name
        raise ValueError(f"unknown window function {win_name!r}") from None

    return win_fn(win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
    """Compute a magnitude spectrogram via the STFT.

    Args:
        x (Tensor): input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): hop size.
        win_length (int): window length.
        window (str): window function name understood by get_window.

    Returns:
        Tensor: magnitude spectrogram, clamped from below at 1e-7.
    """
    analysis_window = get_window(window, win_length).to(x.device)
    spectrum = torch.stft(x, fft_size, hop_size, win_length, analysis_window,
                          return_complex=True)
    return spectrum.abs().clamp(min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
    """Batch-mean spectral convergence between two (possibly complex) spectra."""
    reduce_dims = list(range(1, len(Y_pred.shape)))
    numerator = torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=reduce_dims)
    denominator = torch.norm(Y_pred, p="fro", dim=reduce_dims) + 1e-6
    return (numerator / denominator).mean()
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference of log magnitudes (1e-15 floor for stability)."""
    eps = 1e-15
    log_diff = torch.log(torch.abs(Y_true) + eps) - torch.log(torch.abs(Y_pred) + eps)
    return log_diff.abs().mean()
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the mean normalized cross-correlation of magnitude spectra."""
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, len(mag_pred.shape)))
    numerator = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    denominator = torch.sqrt(torch.sum(mag_true ** 2, dim=reduce_dims)
                             * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9)
    return 1 - (numerator / denominator).mean()
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel spectrogram distance.

    Averages the log-magnitude loss over mel spectrograms computed at several
    FFT sizes; resolutions below 128 use half the number of mel bands.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        super().__init__()
        self.fft_sizes = fft_sizes
        self.overlap   = overlap
        self.fs        = fs
        self.n_mels    = n_mels

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop = int(round(fft_size * (1 - self.overlap)))
            # small FFTs get half the mel bands
            mels = self.n_mels if fft_size >= 128 else self.n_mels // 2
            self.mel_specs.append(
                torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop, n_mels=mels))

        # register the transforms so they follow .to()/state_dict handling
        for index, spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{index+1}', spec)

    def forward(self, y_true, y_pred):

        total = torch.zeros(1, device=y_true.device)
        for spec in self.mel_specs:
            total = total + log_magnitude_loss(spec(y_true), spec(y_pred))

        return total / len(self.mel_specs)
def create_weight_matrix(num_bins, bins_per_band=10):
    """Band-averaging matrix: row i averages ~bins_per_band bins around bin i.

    Bins falling off either edge are folded back so every row sums to one.
    """
    w = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    half_lo = bins_per_band // 2
    half_hi = bins_per_band - half_lo

    for row in range(num_bins):
        lo = max(row - half_lo, 0)
        hi = min(row + half_hi, num_bins)
        w[row, lo:hi] += 1

        # fold back the part of the band that falls off the low/high edge
        if row < half_lo:
            w[row, :half_lo - row] += 1
        if row > num_bins - half_hi:
            w[row, num_bins - half_hi - row:] += 1

    return w / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """L1 spectral convergence weighted by per-band spectral flatness.

    *w* is a (F, F) band-averaging matrix (see create_weight_matrix); tonal
    regions (low spectral flatness) receive higher weight.
    """
    mag_true = torch.abs(Y_true)
    log_mag = torch.log(mag_true + 1e-9)

    # per-band spectral flatness: geometric mean over arithmetic mean
    band_log = torch.matmul(log_mag.transpose(1, 2), w)
    band_lin = torch.matmul(mag_true.transpose(1, 2), w)
    sfm = torch.exp(band_log) / (band_lin + 1e-9)
    weight = torch.relu(1 - sfm).sqrt().transpose(1, 2)

    numerator = torch.mean(weight * torch.abs(mag_true - torch.abs(Y_pred)), dim=[1, 2])
    denominator = torch.mean(weight * mag_true, dim=[1, 2]) + 1e-9

    return torch.mean(numerator / denominator)
def gen_filterbank(N, Fs=16000):
    """Auditory smoothing filterbank of shape (N, N+1) with rows summing to 1.

    Filter widths follow the ERB scale; each row is a normalized response
    centered on one output frequency bin.
    """
    freq_in  = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    freq_out = (np.arange(N,     dtype='float32') / N * Fs / 2)[:, None]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb = 24.7 + .108 * freq_in
    delta = np.abs(freq_in - freq_out) / erb

    inside = (delta < .5).astype('float32')
    response_db = -12 * inside * delta ** 2 + (1 - inside) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)

    # normalize each row to unit sum
    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-magnitude distance on filterbank-smoothed spectra."""
    smooth_true = torch.matmul(filterbank, torch.abs(Y_true))
    smooth_pred = torch.matmul(filterbank, torch.abs(Y_pred))
    diff = torch.log(smooth_true + 1e-9) - torch.log(smooth_pred + 1e-9)
    return diff.abs().mean()
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss.

    Sums weighted log-magnitude, spectral-convergence, SFM-weighted
    spectral-convergence, smoothed log-magnitude and spectral
    cross-correlation terms, averaged over the configured FFT resolutions.
    Terms with zero weight are skipped entirely.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # band-averaging matrices for the SFM-weighted spectral convergence loss
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            band_width = min(11, int(1000 * fft_size / self.fs + .5))
            band_width += band_width % 2  # force even width
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, band_width),
                requires_grad=False
            )

        # auditory filterbanks for the smoothed log-magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size // 2),
                requires_grad=False
            )

    def __call__(self, y_true, y_pred):

        device = y_true.device
        terms = {
            'lm': torch.zeros(1, device=device),
            'sc': torch.zeros(1, device=device),
            'wsc': torch.zeros(1, device=device),
            'slm': torch.zeros(1, device=device),
            'sxcorr': torch.zeros(1, device=device),
        }

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            Y_true = stft(y_true, fft_size, hop_size, fft_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, fft_size, self.window)

            # only evaluate terms that actually contribute
            if self.log_mag_weight > 0:
                terms['lm'] = terms['lm'] + log_magnitude_loss(Y_true, Y_pred)
            if self.sc_weight > 0:
                terms['sc'] = terms['sc'] + spectral_convergence_loss(Y_true, Y_pred)
            if self.wsc_weight > 0:
                terms['wsc'] = terms['wsc'] + weighted_spectral_convergence(
                    Y_true, Y_pred, self.wsc_weights[str(fft_size)])
            if self.smooth_log_mag_weight > 0:
                terms['slm'] = terms['slm'] + smooth_log_mag(
                    Y_true, Y_pred, self.filterbanks[str(fft_size)])
            if self.sxcorr_weight > 0:
                terms['sxcorr'] = terms['sxcorr'] + spectral_xcorr_loss(Y_true, Y_pred)

        total = (self.log_mag_weight * terms['lm']
                 + self.sc_weight * terms['sc']
                 + self.wsc_weight * terms['wsc']
                 + self.smooth_log_mag_weight * terms['slm']
                 + self.sxcorr_weight * terms['sxcorr'])

        return total / len(self.fft_sizes)

View File

@@ -0,0 +1,34 @@
import torch
import scipy.signal
from utils.layers.fir import FIR
class TDLowpass(torch.nn.Module):
    """Time-domain lowpassed error loss.

    Lowpass-filters the error signal y_true - y_pred with an FIR filter and
    returns the mean absolute value of its *power*-th power, along with the
    filtered error itself.
    """
    def __init__(self, numtaps, cutoff, power=2):
        """numtaps/cutoff parameterize the FIR lowpass (scipy firwin)."""
        super().__init__()

        self.b = scipy.signal.firwin(numtaps, cutoff)
        # Fix: the filter was stored as a plain tensor attribute, so .to()/.cuda()
        # did not move it with the module. A non-persistent buffer follows device
        # and dtype casts while keeping state_dict keys unchanged.
        self.register_buffer('weight',
                             torch.from_numpy(self.b).float().view(1, 1, -1),
                             persistent=False)
        self.power = power

    def forward(self, y_true, y_pred):
        """Return (loss, lowpassed error); inputs must be 3-D (B, C, T)."""
        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

        diff = y_true - y_pred
        diff_lp = torch.nn.functional.conv1d(diff, self.weight)

        loss = torch.mean(torch.abs(diff_lp ** self.power))

        return loss, diff_lp

    def get_freqz(self):
        """Return (freq, response) of the underlying FIR filter."""
        freq, response = scipy.signal.freqz(self.b)

        return freq, response

View File

@@ -0,0 +1,93 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import sys
import argparse
import yaml
from utils.templates import setup_dict
# Generate a default training setup YAML for the selected model.
parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

args = parser.parse_args()

# adversarial setups are registered under "<model>_adv"
key = args.model + "_adv" if args.adversarial else args.model

try:
    setup = setup_dict[key]
except KeyError:
    print("setup not found, adversarial training possibly not specified for model")
    sys.exit(1)

# update dataset if given (idiomatic None check instead of type() comparison)
if args.path2dataset is not None:
    setup['dataset'] = args.path2dataset

name = args.name
if not name.endswith('.yml'):
    name += '.yml'

if __name__ == '__main__':
    with open(name, 'w') as f:
        f.write(yaml.dump(setup))

View File

@@ -0,0 +1,42 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from .lace import LACE
from .no_lace import NoLACE
from .lavoce import LaVoce
from .lavoce_400 import LaVoce400
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
# Registry mapping the model name stored in checkpoints/configs to its class.
model_dict = {
    'lace': LACE,
    'nolace': NoLACE,
    'lavoce': LaVoce,
    'lavoce400': LaVoce400,
    'fdmresdisc': FDMResDisc,
}

View File

@@ -0,0 +1,974 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio
from utils.spec import gen_filterbank
# auxiliary functions
def remove_all_weight_norms(module):
    """Strip the weight_norm reparametrization from every submodule that has one.

    Submodules are detected by the presence of the 'weight_v' attribute that
    weight_norm installs.
    """
    for submodule in module.modules():
        if hasattr(submodule, 'weight_v'):
            nn.utils.remove_weight_norm(submodule)
def create_smoothing_kernel(h, w, gamma=1.5):
    """Return a normalized (h, w) Gaussian-like smoothing window.

    *gamma* scales the falloff width relative to the half-extent of each axis.
    """
    center_h = h / 2 - 0.5
    center_w = w / 2 - 0.5
    sigma_h = gamma * center_h
    sigma_w = gamma * center_w

    dist_h = ((torch.arange(h) - center_h) / sigma_h) ** 2
    dist_w = ((torch.arange(w) - center_w) / sigma_w) ** 2

    kernel = torch.exp(-(dist_h.view(-1, 1) + dist_w.view(1, -1)))
    return kernel / kernel.sum()
def create_kernel(h, w, sh, sw):
    """Build an (h, w) kernel by smoothing an (sh, sw) partition-of-unity block.

    The proto kernel of ones gives a disjoint partition of 1; convolving it
    with the smoothing window eta spreads it to the requested size.
    """
    proto = torch.ones((sh, sw))

    eta_h, eta_w = h - sh + 1, w - sw + 1
    assert eta_h > 0 and eta_w > 0

    eta = create_smoothing_kernel(eta_h, eta_w).view(1, 1, eta_h, eta_w)
    padded = F.pad(proto, [eta_w - 1, eta_w - 1, eta_h - 1, eta_h - 1]).unsqueeze(0).unsqueeze(0)

    return F.conv2d(padded, eta)
# positional embeddings
class FrequencyPositionalEmbedding(nn.Module):
    """Appends sin/cos of the circular frequency position as two extra channels.

    Input is (B, C, F, T); output is (B, C + 2, F, T).
    """

    def forward(self, x):
        num_freqs = x.size(2)
        phase = torch.arange(0, num_freqs, dtype=x.dtype, device=x.device) * torch.pi * 2 / num_freqs
        sin = torch.sin(phase).reshape(1, 1, -1, 1)
        cos = torch.cos(phase).reshape(1, 1, -1, 1)
        pad = torch.zeros_like(x[:, 0:1, :, :])

        return torch.cat((x, pad + sin, pad + cos), dim=1)
class PositionalEmbedding2D(nn.Module):
    """Transformer-style 2-D positional embedding appended as 4*d extra channels.

    For each of the *d* frequencies, sin and cos of the height and width
    positions are appended: input (B, C, H, W) -> output (B, C + 4*d, H, W).
    """
    def __init__(self, d=5):
        super().__init__()

        self.d = d

    def forward(self, x):
        N = x.size(2)
        M = x.size(3)

        h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
        w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
        coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)

        h_sin = torch.sin(coeffs * h_args)
        # bug fix: the cosine channels were previously computed with torch.sin,
        # which duplicated the sine embedding and lost the phase information
        h_cos = torch.cos(coeffs * h_args)
        w_sin = torch.sin(coeffs * w_args)
        w_cos = torch.cos(coeffs * w_args)

        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)

        return y
# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
    """Base class for 2-D CNN discriminators on magnitude spectrograms.

    Subclasses supply the layer stack; this base handles the STFT front end
    (with envelope-following noise injection and a frequency region of
    interest), receptive-field measurement, and relevance/diagnostic tools.
    """
    # upper bound on the probe size used when measuring the receptive field
    RECEPTIVE_FIELD_MAX_WIDTH=10000
    def __init__(self,
                 layers,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7000],
                 noise_gain=1e-3,
                 fmap_start_index=0
                 ):
        """layers: list of nn.Modules; resolution: (n_fft, hop, win_length)."""
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index
        if fmap_start_index >= len(layers):
            raise ValueError(f'fmap_start_index is larger than number of layers')
        # filter bank for noise shaping
        n_fft = resolution[0]
        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )
        # roi bins: indices of the first/last STFT bin inside freq_roi
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1)
        self.init_weights()
        # determine receptive field size, offsets and strides by probing the
        # network with zero inputs of growing size until two neighbouring
        # output positions report equally sized receptive fields
        hw = 1000
        # NOTE(review): if sh or sw comes out 0 the `continue` below retries
        # with the same hw, and after exceeding RECEPTIVE_FIELD_MAX_WIDTH only
        # a warning is printed without leaving the loop — confirm this
        # terminates for all intended layer configurations
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]
            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]
            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)
            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1
            if h0 != h1 or w0 != w1:
                hw = 2 * hw
            else:
                # strides
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]
                if sh == 0 or sw == 0: continue
                # offsets
                oh = hs0[0] - sh * pos0[0]
                ow = ws0[0] - sw * pos0[1]
                # overlap factor
                overlap = w0 / sw + h0 / sh
                #print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")
                self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}
                break
            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")
        # create transposed convolutional kernel
        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)
    def run_layer_stack(self, spec):
        """Run the layer stack on a spectrogram; returns all intermediate outputs."""
        output = []
        x = spec.unsqueeze(1)
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        return output
    def forward(self, x):
        """ returns array with feature maps and final score at index -1 """
        output = []
        x = self.spectrogram(x)
        output = self.run_layer_stack(x)
        return output[self.fmap_start_index:]
    def receptive_field(self, output_pos):
        """Map an output position to its (height, width) input bin/frame ranges."""
        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)
            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w
            return (h_min, h_max), (w_min, w_max)
        else:
            return None, None
    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field probabilistically via autograd (slow) """
        x = torch.randn((1,) + input_dims, requires_grad=True)
        # run input through layers
        y = self.run_layer_stack(x)[-1]
        b, c, h, w = y.shape
        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")
        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1
        (mask * y).sum().backward()
        # nonzero input gradients mark the receptive field of the masked output
        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)
        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()
        return [h_min, h_max], [w_min, w_max]
    def init_weights(self):
        """Orthogonal initialization for all conv/linear/embedding weights."""
        # NOTE(review): loop variable m shadows the module-level `import math as m`
        # alias within this method only
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)
    def spectrogram(self, x):
        """dB magnitude spectrogram with envelope-shaped noise, cropped to the ROI."""
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)
        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
                       window=window, return_complex=True) #[B, F, T]
        x = torch.abs(x)
        # noise floor following spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
        x = x + noise
        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]
        return torchaudio.functional.amplitude_to_DB(x,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)#torch.sqrt(x)
    def grad_map(self, x):
        """Return (spectrogram, |d loss/d spectrogram|) for saliency inspection."""
        self.zero_grad()
        n_fft, hop_length, win_length = self.resolution
        window = getattr(torch, 'hann_window')(win_length).to(x.device)
        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)
        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
        specgram.requires_grad = True
        specgram.retain_grad()
        if specgram.grad is not None:
            specgram.grad.zero_()
        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]
        scores = self.run_layer_stack(y)[-1]
        loss = torch.mean((1 - scores) ** 2)
        loss.backward()
        return specgram.data[0], torch.abs(specgram.grad)[0]
    def relevance_map(self, x):
        """Spread output scores back over the input spectrogram via a transposed conv."""
        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)
        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
                       window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)
        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
        scores = self.forward(x)[-1]
        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
        with torch.no_grad():
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2))
            rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw]
            relevance = torch.zeros_like(specgram)
            relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv
        return specgram, relevance
    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """
        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations
        def newconv2d(layer,g):
            # copy of `layer` with weights/biases transformed by g (e.g. clamped)
            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)
            try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError: pass
            try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError: pass
            return new_layer
        # per-FFT-size dB input bounds for the z^B rule at the first layer
        # (presumably derived from amplitude_to_DB's top_db=80 range — TODO confirm)
        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }
        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]
        remove_all_weight_norms(self)
        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()
        num_layers = len(self.layers)
        X = self.spectrogram(x). detach()
        # forward pass
        A = [X.unsqueeze(1)] + [None] * len(self.layers)
        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])
        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)
        # keep only confidently real/fake scores as relevance seeds
        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask
        # backward pass
        R = [None] * num_layers + [r]
        for l in range(1, num_layers)[::-1]:
            A[l] = (A[l]).data.requires_grad_(True)
            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l+1] / z).data
            (z*s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data
        # first layer: z^B rule with per-bin lower/upper input bounds
        A[0] = (A[0].data).requires_grad_(True)
        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)
        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]
            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))
        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))
        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad
        R[0] = (A[0] * c + Xl * cp + Xh * cm)
        #R[0] = (A[0] * c).data
        return X, R[0].mean(dim=1)
def create_3x3_conv_plan(num_layers : int,
                         f_stretch : int,
                         f_down : int,
                         t_stretch : int,
                         t_down : int
                         ):
    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries [(stride_t, stride_f), (dilation_t, dilation_f), (padding_t, padding_f)]
    """
    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):
        """Per-axis plan: stride layers first, then dilation layers, then fill."""

        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch) , stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)

        final_stride = 2 ** (max(down - stride_layers, 0))

        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        # bug fix: the product was previously computed and discarded
        # (`current_dilation * final_dilation`), so the final dilation
        # doubling requested via final_dilation never took effect
        current_dilation *= final_dilation
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan
class DiscriminatorExperimental(SpecDiscriminatorBase):
    """Conv2d discriminator stack with frequency positional embeddings.

    Channels double every layer (capped at max_channels); frequency is
    downsampled by stride 2 per layer; a final sigmoid layer produces scores.
    """
    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding= (1, 1)
        in_channels = 1 + 2  # input plus the two positional-embedding channels
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    # bug fix: `weight = weight + bias_val` only rebound the
                    # loop variable to a fresh tensor and left the parameter
                    # untouched; add in place instead
                    weight += bias_val
# Layout tables for DiscriminatorMagFree, indexed by design name and then by
# FFT size (resolution[0]). Each entry is a pair of exponents for the two
# spectrogram axes, consumed by create_3x3_conv_plan:
#   'stretch': receptive field grows by a factor of 2**exponent per axis
#   'down':    resolution is reduced by a factor of 2**exponent per axis
configs = {
    # downsampling along one axis only
    'f_down': {
        'stretch' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    # stretch and downsampling along both axes
    'ft_down': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    # receptive-field stretch via dilation only, no downsampling
    'dilated': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (0, 0),
            256: (0, 0),
            512: (0, 0),
            1024: (0, 0),
            2048: (0, 0)
        }
    },
    # stretch along both axes, downsampling along the first axis only
    'mixed': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
}
class DiscriminatorMagFree(SpecDiscriminatorBase):
    """Spectrogram discriminator with a configurable stride/dilation layout.

    Per-layer strides, dilations and paddings are derived from the
    'stretch'/'down' tables in `configs` for the requested `design` and the
    FFT size of the given resolution, via create_3x3_conv_plan.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('error: arch required in DiscriminatorMagFree')

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        # num_layers hidden layers plus the final 1-channel projection
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        in_channels = 1 + 2  # magnitude channel + 2 positional-embedding channels
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # grow channels proportionally to the spatial reduction (product over strides)
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases: nudge all bias parameters away from zero.
        # Bug fix: "weight = weight + bias_val" only rebound the loop
        # variable and left the parameters untouched; += modifies the
        # parameter tensor in place (safe under no_grad).
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight += bias_val
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
    """Magnitude-spectrogram discriminator with frequency-positional embeddings.

    A stack of frequency-downsampling (stride (2, 1)) 3x3 convolutions with
    LeakyReLU activations, followed by a single-channel linear score map.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        conv_stride = (2, 1)
        conv_padding = (1, 1)

        blocks = []
        ch_in = 1 + 2  # spectrogram channel plus two positional channels
        ch_out = self.num_channels
        for _ in range(self.num_layers):
            blocks.append(nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(ch_in, ch_out, (3, 3), stride=conv_stride, padding=conv_padding)),
                nn.LeakyReLU(0.2, inplace=True)
            ))
            ch_in = ch_out + 2
            ch_out = min(2 * ch_out, self.num_channels_max)

        # final projection to a single score map (no activation)
        blocks.append(nn.Sequential(
            FrequencyPositionalEmbedding(),
            norm_f(nn.Conv2d(ch_in, 1, (3, 3), padding=conv_padding))
        ))

        super().__init__(layers=blocks, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
    """Magnitude-spectrogram discriminator with 2d positional embeddings.

    Downsamples along both axes (stride (2, 2)); each layer prepends a
    PositionalEmbedding2D, which adds 4*d embedding channels to its input.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.resolution = resolution
        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        embedding_dim = 4 * d  # channels added by PositionalEmbedding2D(d)

        layers = []
        stride = (2, 2)
        padding = (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                # bug fix: pass d explicitly; the final embedding previously
                # relied on PositionalEmbedding2D's default, which breaks the
                # in_channels computation above whenever d differs from it
                PositionalEmbedding2D(d),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
class DiscriminatorMag(SpecDiscriminatorBase):
    """Plain magnitude-spectrogram discriminator without positional embeddings.

    A fixed-width stack of stride-1 3x3 convolutions with LeakyReLU, ending
    in a single-channel linear score map.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        conv_stride = (1, 1)
        conv_padding = (1, 1)

        blocks = []
        ch_in = 1
        for _ in range(self.num_layers):
            blocks.append(nn.Sequential(
                norm_f(nn.Conv2d(ch_in, self.num_channels, (3, 3), stride=conv_stride, padding=conv_padding)),
                nn.LeakyReLU(0.2, inplace=True)
            ))
            ch_in = self.num_channels

        # final 1-channel projection, no activation
        blocks.append(norm_f(nn.Conv2d(ch_in, 1, (3, 3), padding=conv_padding)))

        super().__init__(layers=blocks, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
# Registry mapping architecture names (the `architecture` argument of
# TFDMultiResolutionDiscriminator) to discriminator classes.
discriminators = {
    'mag': DiscriminatorMag,
    'freqpos': DiscriminatorMagFreqPosition,
    '2dpos': DiscriminatorMag2dPositional,
    'experimental': DiscriminatorExperimental,
    'free': DiscriminatorMagFree
}
class TFDMultiResolutionDiscriminator(torch.nn.Module):
    """Ensemble of spectrogram discriminators at multiple STFT resolutions.

    Instantiates one discriminator of the chosen architecture per FFT size
    and runs all of them on the same input.
    """

    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        # rescale the 16 kHz reference FFT sizes to the actual sampling rate
        fft_sizes = [int(round(size_16k * fs / 16000)) for size_16k in fft_sizes_16k]
        # resolution layout: [n_fft, hop size, window size]
        resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]

        disc_cls = discriminators[architecture]
        self.discriminators = nn.ModuleList([
            disc_cls(res, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain,
                     use_spectral_norm=use_spectral_norm, **kwargs)
            for res in resolutions
        ])

    def forward(self, y):
        """Run every sub-discriminator on y and collect their outputs."""
        return [disc(y) for disc in self.discriminators]
class FWGAN_disc_wrapper(nn.Module):
    """Adapts a multi-resolution discriminator to the FWGAN interface.

    For real and generated input, splits each sub-discriminator output into
    its final score (last element) and its feature maps (all earlier
    elements).
    """

    def __init__(self, disc):
        super().__init__()
        self.disc = disc

    def forward(self, y, y_hat):
        real_outputs = self.disc(y)
        fake_outputs = self.disc(y_hat)

        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for real_out, fake_out in zip(real_outputs, fake_outputs):
            # last entry is the discriminator score, the rest are feature maps
            y_d_rs.append(real_out[-1])
            fmap_rs.append(real_out[:-1])
            y_d_gs.append(fake_out[-1])
            fmap_gs.append(fake_out[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

View File

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.sparsification import create_sparsifier
class LACE(NNSBase):
    """ Linear-Adaptive Coding Enhancer

    Enhances decoded SILK speech by running the signal through two adaptive
    comb filters (pitch harmonics) and one adaptive convolution (spectral
    shaping), all conditioned on SILK features, pitch and bitrate.
    """
    # samples per processing frame (5 ms at 16 kHz)
    FRAME_SIZE=80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[10000, 30000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        """See class docstring; skip/preemph are handled by NNSBase."""

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # pitch embedding: lags 0..pitch_max map to dense vectors
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding: log-scale embedding of the bitrate indicator
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net: maps concatenated features to the conditioning vector
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters: pitch-adaptive, zero-delay (symmetric) padding
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping: causal adaptive convolution (left padding only)
        self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        if sparsify:
            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPS at the given sampling rate; optionally print a breakdown."""

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhance signal x conditioned on features, pitch periods and numbits.

        Runs both comb filters and the adaptive convolution with a single
        conditioning vector computed by the feature net.
        """

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        y = self.cf1(x, cf, periods, debug=debug)

        y = self.cf2(y, cf, periods, debug=debug)

        y = self.af1(y, cf, debug=debug)

        return y

    def get_impulse_responses(self, features, periods, numbits):
        """ generates impulse responses on frame centers (input without batch dimension) """

        num_frames = features.size(0)
        batch_size = 32
        max_len = 2 * (self.pitch_max + self.kernel_size) + 10

        # spread out some pulses: one unit impulse per frame, staggered so
        # responses of neighbouring frames land in different batch entries
        x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
        for b in range(batch_size):
            x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1

        # prepare input: replicate conditioning across the batch
        x = torch.from_numpy(x).float().to(features.device)
        features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
        periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
        numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)

        # run network (same pipeline as forward, without gradients)
        with torch.no_grad():
            periods = periods.squeeze(-1)
            pitch_embedding = self.pitch_embedding(periods)
            numbits_embedding = self.numbits_embedding(numbits).flatten(2)
            full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
            cf = self.feature_net(full_features)
            y = self.cf1(x, cf, periods, debug=False)
            y = self.cf2(y, cf, periods, debug=False)
            y = self.af1(y, cf, debug=False)

        # collect responses; drop trailing frames whose response window
        # would run past the end of the signal
        y = y.detach().squeeze().cpu().numpy()
        cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
        num_responses = num_frames - cut_frames
        responses = np.zeros((num_responses, max_len))

        for i in range(num_responses):
            b = i % batch_size
            start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
            stop = start + max_len
            responses[i, :] = y[b, start:stop]

        return responses

View File

@@ -0,0 +1,274 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data
from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
def print_channels(y, prefix="", name="", rate=16000):
    """Dump every channel of the first batch entry of y via endoscopy.

    Each channel is scaled relative to its maximum (assumed positive —
    TODO confirm for signals with a non-positive peak), converted to
    int16 and written with write_data.
    """
    for ch_idx in range(y.size(1)):
        label = f"{prefix}_c{ch_idx:02d}"
        if name:
            label += "_" + name
        samples = y[0, ch_idx, :].detach().cpu().numpy()
        samples = ((2**14) * samples / np.max(samples)).astype(np.int16)
        write_data(label, samples, rate)
class LaVoce(nn.Module):
    """ Linear-Adaptive VOCodEr

    Synthesizes speech from acoustic features and pitch: a phase/noise
    excitation is shaped by adaptive convolutions, temporal shapers and
    pitch-adaptive comb filters, all conditioned on the feature net output.
    """
    # features arrive every 160 samples; filters run on 80-sample frames
    FEATURE_FRAME_SIZE=160
    FRAME_SIZE=80

    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False,
                 innovate1=True,
                 innovate2=False,
                 innovate3=False,
                 ftrans_k=2):

        super().__init__()


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses  # pulse-train excitation instead of sin/cos pair
        self.ftrans_k = ftrans_k  # kernel size of the feature-transform convs

        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding: lags 0..pitch_max map to dense vectors
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net: produces per-frame conditioning vectors
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper: conditioned noise excitation branch
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters: pitch-adaptive, zero-delay (symmetric) padding
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)

        # causal adaptive convolutions for the excitation inputs
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms (temporal envelope shaping / innovation)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)

        # combinators: mix the shaped 2-channel signal back down
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms: evolve the conditioning between stages
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)

    def create_phase_signals(self, periods):
        """Build a phase-continuous 2-channel excitation from per-frame pitch periods.

        periods: (batch, frames) pitch lags in samples. Returns
        (batch, 2, frames * FRAME_SIZE): sin/cos pair, or a two-pulse train
        when self.pulses is set.
        """

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                # rectify the sinusoid into narrow positive/negative pulses
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)

            # carry the phase over to the next frame for continuity
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPS at the given sampling rate; optionally print a breakdown."""

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        """Apply a causal Conv1d feature transform (tanh) along the frame axis."""
        f = f.permute(0, 2, 1)
        f = F.pad(f, [self.ftrans_k - 1, 0])  # causal left padding
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
        """Synthesize a waveform from features and per-frame pitch periods."""

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods to the filter frame rate
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net: mix the phase excitation with shaped noise
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug: print_channels(y, prefix="lavoce_03", name="postselect1")

        # temporal shaping + innovating (channel 0 passes through,
        # channel 1 is shaped, then the pair is re-mixed)
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping: comb filters, then adaptive convolution
        y = self.cf1(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_12", name="postmix2")

        return y

    def process(self, features, periods, debug=False):
        """Inference helper: synthesize, de-emphasize and convert to int16 samples."""

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis: first-order IIR inverting the pre-emphasis filter
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out

View File

@@ -0,0 +1,254 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.noise_shaper import NoiseShaper
from utils.complexity import _conv1d_flop_count
from utils.endoscopy import write_data
from models.nns_base import NNSBase
from models.lpcnet_feature_net import LPCNetFeatureNet
from .scale_embedding import ScaleEmbedding
class LaVoce400(nn.Module):
""" Linear-Adaptive VOCodEr """
FEATURE_FRAME_SIZE=160
FRAME_SIZE=40
    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False):
        """Build the LaVoce400 vocoder (40-sample frames, 160-sample feature frames)."""

        super().__init__()


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses  # pulse-train excitation instead of sin/cos pair

        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding: lags 0..pitch_max map to dense vectors
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net: produces per-frame conditioning vectors
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper: conditioned noise excitation branch
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters: pitch-adaptive, zero-delay (symmetric) padding
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)

        # causal adaptive convolutions for the excitation inputs
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms (temporal envelope shaping / innovation)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)

        # combinators: mix the shaped 2-channel signal back down
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms: evolve the conditioning between stages (kernel 2)
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
    def create_phase_signals(self, periods):
        """Build a phase-continuous 2-channel excitation from per-frame pitch periods.

        periods: (batch, frames) pitch lags in samples. Returns
        (batch, 2, frames * FRAME_SIZE): sin/cos pair, or a two-pulse train
        when self.pulses is set.
        """

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                # rectify the sinusoid into narrow positive/negative pulses
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)

            # carry the phase over to the next frame for continuity
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals
def flop_count(self, rate=16000, verbose=False):
    """Estimate the FLOP count of the model at the given sampling rate.

    Args:
        rate: sampling rate in Hz.
        verbose: when True, print a per-component breakdown.

    Returns:
        Estimated total FLOPS as a number.
    """
    frame_rate = rate / self.FRAME_SIZE

    # conditioning network runs at the frame rate
    feature_net_flops = self.feature_net.flop_count(frame_rate)
    comb_flops = sum(m.flop_count(rate) for m in (self.cf1, self.cf2))
    af_flops = sum(m.flop_count(rate) for m in
                   (self.af1, self.af2, self.af3, self.af4, self.af_prescale, self.af_mix))
    feature_flops = sum(_conv1d_flop_count(m, frame_rate) for m in
                        (self.post_cf1, self.post_cf2, self.post_af1, self.post_af2, self.post_af3))

    if verbose:
        print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
        print(f"comb filters: {comb_flops / 1e6} MFLOPS")
        print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
        print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

    return feature_net_flops + comb_flops + af_flops + feature_flops
def feature_transform(self, f, layer):
    """Refresh conditioning features with a causal width-2 conv + tanh.

    Args:
        f: conditioning features of shape (batch, num_frames, cond_dim).
        layer: an nn.Conv1d(cond_dim, cond_dim, 2) instance.

    Returns:
        Transformed features with the same shape as ``f``.
    """
    h = f.permute(0, 2, 1)           # -> (batch, cond_dim, num_frames)
    h = layer(F.pad(h, [1, 0]))      # left padding keeps the transform causal
    return torch.tanh(h).permute(0, 2, 1)
def forward(self, features, periods, debug=False):
    """Synthesize a signal from frame-rate features and pitch periods.

    Stages: phase/noise prior generation, temporal shaping, pitch-comb and
    spectral filtering, and a final temporal envelope adjustment. The
    conditioning features `cf` are refreshed after every filtering stage.

    Args:
        features: frame-rate feature tensor; last dim is the feature dim.
        periods: pitch periods, shape (batch, num_frames) possibly with a
            trailing singleton dim (squeezed below).
        debug: when True, dumps intermediate prior channels via write_data
            and forwards the flag to the adaptive filter layers.

    Returns:
        Output tensor with shape (batch, 1, num_samples) before squeezing.
    """
    periods = periods.squeeze(-1)
    pitch_embedding = self.pitch_embedding(periods)

    full_features = torch.cat((features, pitch_embedding), dim=-1)
    cf = self.feature_net(full_features)

    # upsample periods to one value per subframe
    periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

    # pre-net: build a two-channel prior from harmonic phase and shaped noise
    ref_phase = torch.tanh(self.create_phase_signals(periods))
    x = self.af_prescale(ref_phase, cf)
    noise = self.noise_shaper(cf)
    y = self.af_mix(torch.cat((x, noise), dim=1), cf)

    if debug:
        # dump the (peak-normalized) prior channels as int16 for inspection
        ch0 = y[0,0,:].detach().cpu().numpy()
        ch1 = y[0,1,:].detach().cpu().numpy()
        ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
        ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
        write_data('prior_channel0', ch0, 16000)
        write_data('prior_channel1', ch1, 16000)

    # temporal shaping + innovating: shape channel 1, keep channel 0, remix
    y1 = y[:, 0:1, :]
    y2 = self.tdshape1(y[:, 1:2, :], cf)
    y = torch.cat((y1, y2), dim=1)
    y = self.af2(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af2)

    y1 = y[:, 0:1, :]
    y2 = self.tdshape2(y[:, 1:2, :], cf)
    y = torch.cat((y1, y2), dim=1)
    y = self.af3(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af3)

    # spectral shaping: pitch-adaptive comb filters, then adaptive conv
    y = self.cf1(y, cf, periods, debug=debug)
    cf = self.feature_transform(cf, self.post_cf1)

    y = self.cf2(y, cf, periods, debug=debug)
    cf = self.feature_transform(cf, self.post_cf2)

    y = self.af1(y, cf, debug=debug)
    cf = self.feature_transform(cf, self.post_af1)

    # final temporal env adjustment, mixed down to a single channel
    y1 = y[:, 0:1, :]
    y2 = self.tdshape3(y[:, 1:2, :], cf)
    y = torch.cat((y1, y2), dim=1)
    y = self.af4(y, cf, debug=debug)

    return y
def process(self, features, periods, debug=False):
    """Run single-item inference and return the signal as int16 samples.

    Args:
        features: un-batched feature tensor (num_frames, feature_dim).
        periods: un-batched pitch periods.
        debug: forwarded to ``forward``.

    Returns:
        1-D torch.int16 tensor of output samples.
    """
    self.eval()
    device = next(iter(self.parameters())).device

    with torch.no_grad():
        out = self.forward(features.unsqueeze(0).to(device),
                           periods.unsqueeze(0).to(device),
                           debug=debug).squeeze()

        # deemphasis: invert the single-pole preemphasis filter in place
        if self.preemph > 0:
            for n in range(1, len(out)):
                out[n] += self.preemph * out[n - 1]

        # scale to int16 range and clip
        return torch.clip((2**15) * out, -2**15, 2**15 - 1).short()

View File

@@ -0,0 +1,91 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class LPCNetFeatureNet(nn.Module):
    """Conditioning network for LPCNet-style features.

    Two conv layers extract context, a transposed conv upsamples by
    ``upsamp_factor``, and a GRU smooths the result over time.
    """

    def __init__(self,
                 feature_dim=84,
                 num_channels=256,
                 upsamp_factor=2,
                 lookahead=True):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.upsamp_factor = upsamp_factor
        self.lookahead = lookahead

        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)

        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)

    def flop_count(self, rate=100):
        """Approximate FLOPS at the given conditioning frame rate."""
        conv_flops = sum(_conv1d_flop_count(c, rate) for c in (self.conv1, self.conv2, self.tconv))
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """features shape: (batch_size, num_frames, feature_dim);
        returns (batch_size, num_frames * upsamp_factor, num_channels)."""
        batch_size = features.size(0)
        if state is None:
            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)
        if self.lookahead:
            # symmetric padding on conv1 uses one frame of lookahead
            h = torch.tanh(self.conv1(F.pad(h, [1, 1])))
            h = torch.tanh(self.conv2(F.pad(h, [2, 0])))
        else:
            # strictly causal: left padding only
            h = torch.tanh(self.conv1(F.pad(h, [2, 0])))
            h = torch.tanh(self.conv2(F.pad(h, [2, 0])))

        h = torch.tanh(self.tconv(h))
        h = h.permute(0, 2, 1)

        out, _ = self.gru(h, state)
        return out

View File

@@ -0,0 +1,69 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class NNSBase(nn.Module):
    """Base class for neural SILK enhancement models.

    Provides the shared single-item inference helper ``process`` which runs
    the subclass ``forward``, undoes preemphasis, compensates model delay
    and converts the result to int16.
    """

    def __init__(self, skip=91, preemph=0.85):
        super().__init__()
        self.skip = skip        # model delay in samples, dropped from the output
        self.preemph = preemph  # preemphasis factor inverted after synthesis

    def process(self, sig, features, periods, numbits, debug=False):
        """Run inference on one un-batched item; returns int16 samples."""
        self.eval()

        # some subclasses condition on the bit budget ('numbits'), others don't
        wants_numbits = 'numbits' in self.forward.__code__.co_varnames
        device = next(iter(self.parameters())).device

        with torch.no_grad():
            x = sig.view(1, 1, -1).to(device)
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)
            n = numbits.unsqueeze(0).to(device)

            args = (x, f, p, n) if wants_numbits else (x, f, p)
            y = self.forward(*args, debug=debug).squeeze()

            # deemphasis: invert the single-pole preemphasis filter
            if self.preemph > 0:
                for i in range(1, len(y)):
                    y[i] += self.preemph * y[i - 1]

            # delay compensation: drop the first `skip` samples, zero-pad the tail
            y = torch.cat((y[self.skip:], torch.zeros(self.skip, dtype=y.dtype, device=y.device)))

            return torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

View File

@@ -0,0 +1,218 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.complexity import _conv1d_flop_count
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
import sys
sys.path.append('../dnntools')
from dnntools.quantization import soft_quant
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
class NoLACE(NNSBase):
    """Non-Linear Adaptive Coding Enhancer.

    Enhances decoded SILK speech at 16 kHz. The signal path applies two
    pitch-adaptive comb filters, several limited adaptive convolutions and
    three non-linear temporal shaping stages (TDShaper), all conditioned on
    features produced by the feature net. The conditioning features are
    refreshed between stages by small causal convolutions (post_*).
    """

    # samples per processing frame at 16 kHz
    FRAME_SIZE=80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 avg_pool_k=4,
                 pool_after=False,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[100, 1000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        # NOTE(review): list-typed defaults are shared across instances
        # (mutable default arguments); they appear to be read-only here.

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # a scalar density is broadcast to all ten sparsified weight groups
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 10 * [sparsification_density]

        # identity wrapper when weight norm is disabled
        norm = weight_norm if apply_weight_norm else lambda x, name=None: x

        # pitch embedding (one entry per pitch lag, 0..pitch_max)
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding: sinusoidal embedding of the frame bit budget
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net (with or without partial lookahead)
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters: centered kernels (left/right padding split)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping: 1-in 2-out causal adaptive conv
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # non-linear transforms (temporal envelope shaping)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # combinators: mix shaped and unshaped channels back together
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # feature transforms: causal width-2 convs refreshing cf between stages
        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))

        if softquant:
            # wrap feature transforms for quantization-aware training
            self.post_cf1 = soft_quant(self.post_cf1)
            self.post_cf2 = soft_quant(self.post_cf2)
            self.post_af1 = soft_quant(self.post_af1)
            self.post_af2 = soft_quant(self.post_af2)
            self.post_af3 = soft_quant(self.post_af3)

        if sparsify:
            # densities 4..8 target the feature transforms; 0..3 are consumed
            # by the feature net (see SilkFeatureNetPL)
            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))

            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate the FLOP count of the model at the given sampling rate."""
        frame_rate = rate / self.FRAME_SIZE

        # feature net runs at the conditioning frame rate
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
        shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            # NOTE(review): shape_flops is included in the total but not
            # printed in this breakdown
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops

    def feature_transform(self, f, layer):
        """Apply a causal width-2 conv + tanh to (batch, frames, cond_dim) features."""
        f0 = f.permute(0, 2, 1)
        f = F.pad(f0, [1, 0])  # left padding keeps the transform causal
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhance signal x conditioned on features, pitch periods and bit budget.

        Args:
            x: input signal, shape (batch, 1, num_samples).
            features: SILK features, shape (batch, num_frames, num_features).
            periods: pitch periods (trailing singleton dim is squeezed).
            numbits: bits-per-frame tensor fed to the scale embedding.
            debug: forwarded to the adaptive filter layers.

        Returns:
            Enhanced signal, shape (batch, 1, num_samples).
        """
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # pitch-adaptive comb filtering
        y = self.cf1(x, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        # split into two channels for shaping
        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # shape channel 1, keep channel 0, remix — three times
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)

        return y

View File

@@ -0,0 +1,68 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import torch
from torch import nn
class ScaleEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar restricted to [min_val, max_val].

    The input is clipped to the valid range (after an optional log
    transform), centered, and expanded into ``dim`` sinusoids with learnable
    frequencies initialized to multiples of pi over the range width.
    """

    def __init__(self,
                 dim,
                 min_val,
                 max_val,
                 logscale=False):
        super().__init__()

        if min_val >= max_val:
            raise ValueError('min_val must be smaller than max_val')
        if min_val <= 0 and logscale:
            raise ValueError('min_val must be positive when logscale is true')

        self.dim = dim
        self.logscale = logscale
        # store range bounds in log domain when logscale is requested
        self.min_val = m.log(min_val) if logscale else min_val
        self.max_val = m.log(max_val) if logscale else max_val

        # center of the (possibly log-transformed) range
        self.offset = (self.min_val + self.max_val) / 2

        span = self.max_val - self.min_val
        self.scale_factors = nn.Parameter(
            torch.arange(1, dim + 1, dtype=torch.float32) * torch.pi / span
        )

    def forward(self, x):
        """Embed scalar tensor x; output gains a trailing dim of size ``dim``."""
        if self.logscale:
            x = torch.log(x)
        centered = torch.clip(x, self.min_val, self.max_val) - self.offset
        return torch.sin(centered.unsqueeze(-1) * self.scale_factors - 0.5)

View File

@@ -0,0 +1,179 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.deemph import Deemph
from utils.misc import freeze_model
from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
class ShapeUp48(NNSBase):
    """Bandwidth extension model: enhances 16 kHz SILK output and upsamples
    it to ``target_fs`` (default 48 kHz) in two stages (16 -> 32 -> 48 kHz),
    with adaptive convolutions, temporal shaping and shaped noise injection.
    An optional frozen ``prenet`` enhancement model can be run first.
    """

    # samples per frame at 16 kHz
    FRAME_SIZE16k=80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead
        # frame sizes at the intermediate (32 kHz) and target sampling rates
        self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32 = self.FRAME_SIZE16k * 2
        self.noise_amplitude = noise_amplitude
        self.prenet = prenet

        # freeze prenet if given: it is used for inference only
        if prenet is not None:
            freeze_model(self.prenet)
            # the prenet output is deemphasized with its own preemph factor;
            # narrowed from a bare `except:` so real errors are not swallowed
            try:
                self.deemph = Deemph(prenet.preemph)
            except AttributeError:
                print("[warning] prenet model is expected to have preemph attribute")
                self.deemph = Deemph(0)

        # upsampler
        self.upsampler = SilkUpsampler()

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (sinusoidal embedding of the frame bit budget)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms (at 32 kHz and 48 kHz respectively)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate the FLOP count at the given base (16 kHz) sampling rate.

        NOTE(review): af1 runs on the 32 kHz signal in ``forward`` but is
        counted at ``rate`` here, and the TDShaper / af_noise stages are not
        counted at all — confirm whether this approximation is intentional.
        """
        frame_rate = rate / self.FRAME_SIZE16k

        # feature net runs at the conditioning frame rate
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Upsample and enhance x (16 kHz) to the target rate.

        Args:
            x: input signal, shape (batch, 1, num_samples) at 16 kHz.
            features: SILK features, shape (batch, num_frames, num_features).
            periods: pitch periods (trailing singleton dim is squeezed).
            numbits: bits-per-frame tensor fed to the scale embedding.
            debug: forwarded to the adaptive filter layers.

        Returns:
            Output signal at the target sampling rate.
        """
        # optional frozen enhancement prenet; its preemphasis is undone here
        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # stage 1: 16 kHz -> 32 kHz, with shaped-noise injection
        y32 = self.upsampler.hq_2x_up(x)

        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)

        y32 = self.af1(y32, cf, debug=debug)

        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)

        y32 = self.af2(y32, cf, debug=debug)

        # stage 2: 32 kHz -> 48 kHz
        y48 = self.upsampler.interpolate_3_2(y32)

        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
        y48 = torch.cat((y48_1, y48_2), dim=1)

        y48 = self.af3(y48, cf, debug=debug)

        return y48

View File

@@ -0,0 +1,86 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class SilkFeatureNet(nn.Module):
    """Conditioning network for SILK features: two convolution layers
    (the second dilated) followed by a GRU, all at the conditioning
    frame rate."""

    def __init__(self,
                 feature_dim=47,
                 num_channels=256,
                 lookahead=False):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.lookahead = lookahead

        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)

        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

    def flop_count(self, rate=200):
        """Approximate FLOPS at the given conditioning frame rate."""
        conv_flops = sum(_conv1d_flop_count(c, rate) for c in (self.conv1, self.conv2))
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """features shape: (batch_size, num_frames, feature_dim)"""
        batch_size = features.size(0)
        if state is None:
            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)
        if self.lookahead:
            # symmetric padding uses future frames
            h = torch.tanh(self.conv1(F.pad(h, [1, 1])))
            h = torch.tanh(self.conv2(F.pad(h, [2, 2])))
        else:
            # strictly causal: left padding only
            h = torch.tanh(self.conv1(F.pad(h, [2, 0])))
            h = torch.tanh(self.conv2(F.pad(h, [4, 0])))

        h = h.permute(0, 2, 1)
        out, _ = self.gru(h, state)
        return out

View File

@@ -0,0 +1,127 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import sys
sys.path.append('../dnntools')
import numbers
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from utils.complexity import _conv1d_flop_count
from dnntools.quantization.softquant import soft_quant
from dnntools.sparsification import mark_for_sparsification
class SilkFeatureNetPL(nn.Module):
    """Conditioning network with partial lookahead.

    Reduces per-frame features with a 1x1 conv, stacks groups of four
    reduced frames, applies a causal width-2 conv, upsamples back by 4 with
    a transposed conv and smooths with a GRU. Supports optional weight
    norm, soft quantization and sparsification markup.
    """

    def __init__(self,
                 feature_dim=47,
                 num_channels=256,
                 hidden_feature_dim=64,
                 softquant=False,
                 sparsify=True,
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        super().__init__()

        # a scalar density is broadcast to the four sparsified weight groups
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = [sparsification_density] * 4

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.hidden_feature_dim = hidden_feature_dim

        # identity wrapper when weight norm is disabled
        norm = weight_norm if apply_weight_norm else (lambda x, name=None: x)

        self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
        self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
        self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))

        gru = nn.GRU(num_channels, num_channels, batch_first=True)
        self.gru = norm(norm(gru, name='weight_hh_l0'), name='weight_ih_l0')

        if softquant:
            self.conv2 = soft_quant(self.conv2)
            self.tconv = soft_quant(self.tconv)
            self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])

        if sparsify:
            mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
            mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
            mark_for_sparsification(
                self.gru,
                {
                    'W_ir' : (sparsification_density[2], [8, 4], False),
                    'W_iz' : (sparsification_density[2], [8, 4], False),
                    'W_in' : (sparsification_density[2], [8, 4], False),
                    'W_hr' : (sparsification_density[3], [8, 4], True),
                    'W_hz' : (sparsification_density[3], [8, 4], True),
                    'W_hn' : (sparsification_density[3], [8, 4], True),
                }
            )

    def flop_count(self, rate=200):
        """Approximate FLOPS at the given conditioning frame rate."""
        conv_flops = sum(_conv1d_flop_count(c, rate) for c in (self.conv1, self.conv2, self.tconv))
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """features shape: (batch_size, num_frames, feature_dim);
        num_frames is expected to be a multiple of 4."""
        batch_size = features.size(0)
        num_frames = features.size(1)
        if state is None:
            state = torch.zeros((1, batch_size, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)

        # per-frame dimensionality reduction
        h = torch.tanh(self.conv1(h))

        # frame accumulation: stack groups of 4 reduced frames
        h = h.permute(0, 2, 1)
        h = h.reshape(batch_size, num_frames // 4, -1).permute(0, 2, 1)
        h = torch.tanh(self.conv2(F.pad(h, [1, 0])))

        # upsampling back to the original frame rate
        h = torch.tanh(self.tconv(h))
        h = h.permute(0, 2, 1)

        out, _ = self.gru(h, state)
        return out

View File

@@ -0,0 +1,9 @@
pyyaml==6.0.1
torch==2.0.1
numpy==1.25.2
scipy==1.11.2
pesq==0.0.4
gitpython==3.1.41
matplotlib==3.7.3
torchaudio==2.0.2
tqdm==4.66.1

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
import os
import argparse
import numpy as np
from scipy import signal
from scipy.io import wavfile
import resampy
# CLI: concatenate a list of WAVE files into one raw int16 file at a common
# sampling rate, with optional random per-item level scaling.
parser = argparse.ArgumentParser()
parser.add_argument("filelist", type=str, help="file with filenames for concatenation in WAVE format")
parser.add_argument("target_fs", type=int, help="target sampling rate of concatenated file")
parser.add_argument("output", type=str, help="binary output file (integer16)")
parser.add_argument("--basedir", type=str, help="basedir for filenames in filelist, defaults to ./", default="./")
parser.add_argument("--normalize", action="store_true", help="apply normalization")
parser.add_argument("--db_max", type=float, help="max DB for random normalization", default=0)
parser.add_argument("--db_min", type=float, help="min DB for random normalization", default=0)
parser.add_argument("--verbose", action="store_true")
def read_filelist(basedir, filelist):
    """Return the non-empty lines of *filelist*, each joined onto *basedir*."""
    with open(filelist, "r") as fl:
        lines = fl.readlines()
    result = []
    for line in lines:
        name = line.rstrip('\n')
        if name:
            result.append(os.path.join(basedir, name))
    return result
def read_wave(file, target_fs):
    """Read a wave file and resample it to target_fs.

    Returns the samples as float32, or None when the file's native rate is
    below target_fs (such files are skipped rather than up-sampled).
    """
    fs, x = wavfile.read(file)

    if fs < target_fs:
        # Do not up-sample low-rate material; the caller skips the file.
        # BUGFIX: an unreachable warning print after this return was removed.
        return None

    if fs != target_fs:
        x = resampy.resample(x, fs, target_fs)

    return x.astype(np.float32)
def random_normalize(x, db_min, db_max, max_val=2**15 - 1):
    """Scale x so its peak lands at a random level in [db_min, db_max] dB
    relative to max_val."""
    gain_db = np.random.uniform(db_min, db_max, 1)
    peak = np.abs(x).max()
    scale = 10 ** (gain_db / 20) * max_val / peak
    return scale * x
def concatenate(filelist : str, output : str, target_fs: int, normalize=True, db_min=0, db_max=0, verbose=False):
    """Concatenate all items in *filelist* into one raw int16 file at
    *target_fs*, cross-fading 40/8000 s (5 ms) between consecutive items.

    Items shorter than 10 overlap windows or below the target rate are skipped.
    """
    overlap_size = int(40 * target_fs / 8000)
    overlap_mem = np.zeros(overlap_size, dtype=np.float32)
    # raised-cosine fade-in/fade-out windows for the cross-fade
    overlap_win1 = (0.5 + 0.5 * np.cos(np.arange(0, overlap_size) * np.pi / overlap_size)).astype(np.float32)
    overlap_win2 = np.flipud(overlap_win1)

    with open(output, 'wb') as f:
        for file in filelist:
            x = read_wave(file, target_fs)
            if x is None: continue

            if len(x) < 10 * overlap_size:
                if verbose: print(f"skipping {file}...")
                continue
            elif verbose:
                print(f"processing {file}...")

            if normalize:
                x = random_normalize(x, db_min, db_max)

            # write everything except the last overlap_size samples,
            # cross-faded with the previous item's tail
            x1 = x[:-overlap_size]
            x1[:overlap_size] = overlap_win1 * overlap_mem + overlap_win2 * x1[:overlap_size]

            f.write(x1.astype(np.int16).tobytes())

            # BUGFIX: carry the *unwritten tail* of x (a slice) over to the
            # next item; the original read the scalar x1[-overlap_size],
            # which is both the wrong region and the wrong type.
            overlap_mem = x[-overlap_size:]
if __name__ == "__main__":
    # Parse CLI arguments and run the concatenation pipeline.
    args = parser.parse_args()
    filelist = read_filelist(args.basedir, args.filelist)
    concatenate(filelist, args.output, args.target_fs, normalize=args.normalize, db_min=args.db_min, db_max=args.db_max, verbose=args.verbose)

View File

@@ -0,0 +1,28 @@
import argparse
from scipy.io import wavfile
import torch
import numpy as np
from utils.layers.silk_upsampler import SilkUpsampler
# CLI: up-sample a 16 kHz / int16 wave file to 48 kHz with the SILK upsampler.
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input wave file")
parser.add_argument("output", type=str, help="output wave file")

if __name__ == "__main__":
    args = parser.parse_args()

    fs, x = wavfile.read(args.input)

    # being lazy for now
    assert fs == 16000 and x.dtype == np.int16

    # shape (batch, channels, samples) as expected by the upsampler
    x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)

    upsampler = SilkUpsampler()
    y = upsampler(x)

    y = y.squeeze().numpy().astype(np.int16)

    # drop the first 13 samples -- presumably the upsampler's filter delay;
    # TODO confirm against SilkUpsampler
    wavfile.write(args.output, 48000, y[13:])

View File

@@ -0,0 +1,123 @@
import argparse
import os
import yaml
import subprocess
import numpy as np
# CLI: sample a gender-balanced clip subset per CommonVoice language and
# convert the selected clips to 16 kHz wav files.
parser = argparse.ArgumentParser()
parser.add_argument('commonvoice_base_dir')
parser.add_argument('output_dir')
parser.add_argument('--clips-per-language', required=False, type=int, default=10)
parser.add_argument('--seed', required=False, type=int, default=2024)
def select_clips(dir, num_clips=10):
    """Select num_clips//2 female and num_clips//2 male clips, one clip per
    client, from a CommonVoice language directory.

    Returns a list of full clip paths, or None when the language does not
    contain enough validated data for a balanced selection.
    """
    if num_clips % 2:
        print(f"warning: number of clips will be reduced to {num_clips - 1}")

    female = dict()
    male = dict()

    # validated.tsv: column 0 = client id, column 1 = clip filename,
    # column 8 = gender label -- assumes this CommonVoice schema; TODO confirm
    clips = np.genfromtxt(os.path.join(dir, 'validated.tsv'), delimiter='\t', dtype=str, invalid_raise=False)
    clips_by_client = dict()

    if len(clips.shape) < 2 or len(clips) < num_clips:
        # not enough data to proceed
        return None

    # group clips by client and gender (row 0 is the tsv header)
    for client in set(clips[1:,0]):
        client_clips = clips[clips[:, 0] == client]
        f, m = False, False
        if 'female_feminine' in client_clips[:, 8]:
            female[client] = client_clips[client_clips[:, 8] == 'female_feminine']
            f = True
        if 'male_masculine' in client_clips[:, 8]:
            male[client] = client_clips[client_clips[:, 8] == 'male_masculine']
            m = True
        if f and m:
            print(f"both male and female clips under client {client}")

    if min(len(female), len(male)) < num_clips // 2:
        return None

    # select num_clips // 2 random female clients
    female_client_selection = np.array(list(female.keys()), dtype=str)[np.random.choice(len(female), num_clips//2, replace=False)]
    female_clip_selection = []
    for c in female_client_selection:
        # one random clip per selected client
        s_idx = np.random.randint(0, len(female[c]))
        female_clip_selection.append(os.path.join(dir, 'clips', female[c][s_idx, 1].item()))

    # select num_clips // 2 random male clients
    male_client_selection = np.array(list(male.keys()), dtype=str)[np.random.choice(len(male), num_clips//2, replace=False)]
    male_clip_selection = []
    for c in male_client_selection:
        s_idx = np.random.randint(0, len(male[c]))
        male_clip_selection.append(os.path.join(dir, 'clips', male[c][s_idx, 1].item()))

    return female_clip_selection + male_clip_selection
def ffmpeg_available():
    """Return True when an ffmpeg binary can be executed on this system."""
    try:
        x = subprocess.run(['ffmpeg', '-h'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return x.returncode == 0
    except OSError:
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; a missing binary raises
        # FileNotFoundError (an OSError subclass).
        return False
def convert_clips(selection, outdir):
    """Convert every selected clip to a 16 kHz wav under <outdir>/clips.

    Returns {language: [output paths relative to outdir]}; raises
    RuntimeError when ffmpeg is missing or a conversion fails.
    """
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")

    clip_folder = os.path.join(outdir, 'clips')
    os.makedirs(clip_folder, exist_ok=True)

    converted = dict()
    for lang, clips in selection.items():
        converted[lang] = []
        for clip in clips:
            stem = os.path.splitext(os.path.split(clip)[-1])[0]
            target_name = os.path.join('clips', stem + '.wav')
            call_args = ['ffmpeg', '-i', clip, '-ar', '16000', os.path.join(outdir, target_name)]
            print(call_args)
            result = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            if result.returncode != 0:
                raise RuntimeError(f'could not execute {call_args}')
            converted[lang].append(target_name)

    return converted
if __name__ == "__main__":
    # fail fast before any selection work is done
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")

    args = parser.parse_args()

    base_dir = args.commonvoice_base_dir
    output_dir = args.output_dir
    seed = args.seed

    # deterministic clip selection for a given seed
    np.random.seed(seed)

    langs = os.listdir(base_dir)
    selection = dict()
    for lang in langs:
        print(f"processing {lang}...")
        # BUGFIX: forward --clips-per-language; previously the parsed value
        # was ignored and select_clips always used its default of 10.
        clips = select_clips(os.path.join(base_dir, lang), num_clips=args.clips_per_language)
        if clips is not None:
            selection[lang] = clips

    os.makedirs(output_dir, exist_ok=True)
    clips = convert_clips(selection, output_dir)

    # record which clips were produced per language
    with open(os.path.join(output_dir, 'clips.yml'), 'w') as f:
        yaml.dump(clips, f)

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Build SILK-enhancement test data: for every LibriSpeech wave file,
# down-convert to 16 kHz / 16 bit mono raw and run the patched opus_demo
# at each bitrate, producing a clean/noisy pair per item and bitrate.

INPUT="dataset/LibriSpeech"
OUTPUT="testdata"
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
# remaining bitrates kept commented out for quicker runs
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )

mkdir -p $OUTPUT

for fn in $(find $INPUT -name "*.wav")
do
    name=$(basename ${fn%*.wav})
    # resample to 16 kHz signed 16-bit raw PCM
    sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
    for br in ${BITRATES[@]}
    do
        folder=${OUTPUT}/"${name}_${br}.se"
        echo "creating ${folder}..."
        mkdir -p $folder
        cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
        # encode/decode round trip through Opus (SILK, voip mode) at $br bps
        (cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
    done
    rm -f ${OUTPUT}/tmp.raw
done

View File

@@ -0,0 +1,7 @@
#!/bin/bash
# Environment for the OSCE evaluation scripts: python interpreter, model
# checkpoints, the inference script and the patched opus_demo binary.
# Source this file before running the evaluation pipeline.

export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"

View File

@@ -0,0 +1,113 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
from scipy.io import wavfile
from pesq import pesq
import numpy as np
from moc import compare
from moc2 import compare as compare2
#from warpq import compute_WAPRQ as warpq
from lace_loss_metric import compare as laceloss_compare
# CLI: evaluate one objective metric over pre-processed clean/opus/lace/nolace
# items and store the per-bitrate results as an .npy dict.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
def get_bitrates(folder):
    """Read the whitespace-separated bitrate list from <folder>/bitrates.txt."""
    path = os.path.join(folder, 'bitrates.txt')
    with open(path) as fh:
        tokens = fh.read().rstrip('\n').split()
    return [int(tok) for tok in tokens]
def get_itemlist(folder):
    """Return the first whitespace-separated token of each line of
    <folder>/items.txt (the item names)."""
    items = []
    with open(os.path.join(folder, 'items.txt')) as fh:
        for line in fh.readlines():
            items.append(line.split()[0])
    return items
def process_item(folder, item, bitrate, metric):
    """Evaluate *metric* for one item at one bitrate.

    Reads the clean reference and the opus/lace/nolace outputs from the
    corresponding subfolders and returns
    [metric(clean, opus), metric(clean, lace), metric(clean, nolace)].
    """
    fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
    fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
    fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
    fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))

    # int16 -> float in [-1, 1)
    x_clean = x_clean.astype(np.float32) / 2**15
    x_opus = x_opus.astype(np.float32) / 2**15
    x_lace = x_lace.astype(np.float32) / 2**15
    x_nolace = x_nolace.astype(np.float32) / 2**15

    if metric == 'pesq':
        result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
    elif metric =='moc':
        result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
    elif metric =='moc2':
        result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
    # elif metric == 'warpq':
    #     result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
    elif metric == 'laceloss':
        result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
    else:
        raise ValueError(f'unknown metric {metric}')

    return result
def process_bitrate(folder, items, bitrate, metric):
    """Evaluate *metric* for all *items* at *bitrate*.

    Returns an array of shape (len(items), 3); columns are opus/lace/nolace.
    """
    results = np.zeros((len(items), 3))
    for row, item in zip(results, items):
        row[:] = process_item(folder, item, bitrate, metric)
    return results
if __name__ == "__main__":
    args = parser.parse_args()

    items = get_itemlist(args.folder)
    bitrates = get_bitrates(args.folder)

    results = dict()
    for br in bitrates:
        print(f"processing bitrate {br}...")
        results[br] = process_bitrate(args.folder, items, br, args.metric)

    # dict {bitrate: (num_items, 3) array}; reload later with
    # np.load(..., allow_pickle=True).item()
    np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)

    print("Done.")

View File

@@ -0,0 +1,330 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
    """Look up a torch window function by name and instantiate it.

    Raises:
        ValueError: when *win_name* is not a supported window type.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window' : torch.hamming_window,
        'hann_window' : torch.hann_window,
        'kaiser_window' : torch.kaiser_window
    }

    if not win_name in window_dict:
        # BUGFIX: the original raised a bare ValueError() with no message
        raise ValueError(f"unknown window type {win_name!r}, supported types: {list(window_dict.keys())}")

    return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    win = get_window(window, win_length).to(x.device)
    x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)

    # floor the magnitude so downstream logs/divisions stay finite
    return torch.clamp(torch.abs(x_stft), min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
    """Batch mean of ||  |Y_true| - |Y_pred| ||_F / (||Y_pred||_F + eps).

    NOTE(review): the denominator normalizes by Y_pred, not Y_true --
    kept exactly as in the original implementation.
    """
    reduce_dims = list(range(1, Y_pred.dim()))
    err = torch.norm(Y_true.abs() - Y_pred.abs(), p="fro", dim=reduce_dims)
    ref = torch.norm(Y_pred, p="fro", dim=reduce_dims)
    return (err / (ref + 1e-6)).mean()
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference of (eps-stabilized) log magnitudes."""
    eps = 1e-15
    diff = torch.log(Y_true.abs() + eps) - torch.log(Y_pred.abs() + eps)
    return diff.abs().mean()
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the batch-mean normalized magnitude cross-correlation."""
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, mag_pred.dim()))
    num = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    den = torch.sqrt(torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9)
    return 1 - (num / den).mean()
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel spectrogram L1 loss."""
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):
        """
        Args:
            fft_sizes: FFT sizes of the individual resolutions
            overlap: window overlap fraction (hop = fft_size * (1 - overlap))
            fs: sampling rate for the mel filterbanks
            n_mels: number of mel bands (halved for FFT sizes below 128)
        """
        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        super().__init__()

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            # fewer mel bands at very low spectral resolution
            if fft_size < 128:
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register the transforms as submodules so their buffers follow
        # .to(device) calls on the loss module
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Average log-magnitude mel distance over all resolutions."""
        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
def create_weight_matrix(num_bins, bins_per_band=10):
    """Band-averaging matrix: row i averages roughly bins_per_band bins
    around bin i, with extra weight folded back at both spectrum edges."""
    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    half_lo = bins_per_band // 2
    half_hi = bins_per_band - half_lo

    for row in range(num_bins):
        lo = max(row - half_lo, 0)
        hi = min(row + half_hi, num_bins)
        weights[row, lo: hi] += 1

        # fold the part of the band that falls off the lower edge back in
        if row < half_lo:
            weights[row, :half_lo - row] += 1

        # same for the upper edge
        if row > num_bins - half_hi:
            weights[row, num_bins - half_hi - row:] += 1

    return weights / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Magnitude-error loss weighted by per-band spectral flatness.

    w is a band-averaging matrix (see create_weight_matrix); bands with low
    spectral flatness (tonal content) receive larger weights.
    """
    # calculate sfm based weights
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)
    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)
    # spectral flatness = geometric mean / arithmetic mean per band
    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )

    return loss
def gen_filterbank(N, Fs=16000):
    """Auditory spreading filterbank of shape (N, N+1) on an ERB scale;
    every output row is normalized to sum to one."""
    freq_in = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
    freq_out = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_width = 24.7 + .108 * freq_in

    delta = np.abs(freq_in - freq_out) / erb_width
    near = (delta < .5).astype('float32')
    # quadratic rolloff near the center, linear (-12 dB/ERB) further out
    response_db = -12 * near * delta ** 2 + (1 - near) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)

    row_sums = np.sum(response, axis=1)
    response = response / row_sums[:, np.newaxis]
    return torch.from_numpy(response)
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-magnitude error after band-smoothing both
    spectrograms with *filterbank*."""
    smooth_true = torch.matmul(filterbank, torch.abs(Y_true))
    smooth_pred = torch.matmul(filterbank, torch.abs(Y_pred))
    diff = torch.log(smooth_true + 1e-9) - torch.log(smooth_pred + 1e-9)
    return diff.abs().mean()
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: a weighted sum of several spectral
    distances, each accumulated over all resolutions in fft_sizes."""
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=0,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=2,
                 sxcorr_weight=1):
        """
        Args:
            fft_sizes: FFT sizes of the individual resolutions
            overlap: window overlap fraction (hop = fft_size * (1 - overlap))
            window: window name understood by get_window
            fs: sampling rate, used to size the SFM weight matrices
            log_mag_weight: weight of the plain log-magnitude term
            sc_weight: weight of the spectral convergence term
            wsc_weight: weight of the SFM-weighted spectral convergence term
            smooth_log_mag_weight: weight of the ERB-smoothed log-magnitude term
            sxcorr_weight: weight of the spectral cross-correlation term
        """
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width: about 1 kHz, capped at 11 bins, rounded up to even
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )


    def __call__(self, y_true, y_pred):
        """Return the weighted multi-resolution loss for a batch of signals,
        normalized by the number of resolutions."""

        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            # zero-weighted terms are skipped entirely to save compute
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)

        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
def td_l2_norm(y_true, y_pred):
    """Per-item time-domain MSE normalized by the RMS of y_pred,
    averaged over the batch."""
    reduce_dims = list(range(1, y_true.dim()))
    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** 0.5
    return (mse / (rms + 1e-6)).mean()
class LaceLoss(nn.Module):
    """Combined spectral + time-domain loss used for LACE training and,
    via compare(), as an objective quality metric."""
    def __init__(self):
        super().__init__()

        # only the smoothed log-magnitude (weight 2) and spectral
        # cross-correlation (weight 1) terms are active
        self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)


    def forward(self, x, y):
        specloss = self.stftloss(x, y)
        phaseloss = td_l2_norm(x, y)
        # spectral weights (2 + 1) plus 10 for the td term -> normalize by 13
        total_loss = (specloss + 10 * phaseloss) / 13

        return total_loss

    def compare(self, x_ref, x_deg):
        """Distance between reference and degraded numpy float signals
        (scaled by 10; lower is better)."""
        # trim items to same size
        n = min(len(x_ref), len(x_deg))
        x_ref = x_ref[:n].copy()
        x_deg = x_deg[:n].copy()

        # pre-emphasis
        x_ref[1:] -= 0.85 * x_ref[:-1]
        x_deg[1:] -= 0.85 * x_deg[:-1]

        device = next(iter(self.parameters())).device

        x = torch.from_numpy(x_ref).to(device)
        y = torch.from_numpy(x_deg).to(device)

        with torch.no_grad():
            dist = 10 * self.forward(x, y)

        return dist.cpu().numpy().item()
# module-level singleton backing the compare() convenience wrapper below;
# moved to GPU when one is available
lace_loss = LaceLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lace_loss.to(device)

def compare(x, y):
    """LACE-loss distance between reference x and degraded y (numpy arrays)."""
    return lace_loss.compare(x, y)

View File

@@ -0,0 +1,116 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: render box plots from pre-calculated per-metric result files.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load every per-metric result dict (results_<metric>.npy) present in
    *folder*; returns {metric: {bitrate: (num_items, 3) array}}."""
    data = dict()

    for metric in ('moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
def plot_data(filename, data, title=None):
    """Box plot of per-item scores: one Opus/LACE/NoLACE box triple per
    bitrate, written to *filename*."""
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]

    # NOTE: requires a working LaTeX installation for text rendering
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })

    black = '#000000'
    red = '#ff5745'
    blue = '#007dbc'
    colors = [black, red, blue]
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='LACE'),
                       Patch(facecolor=colors[2], label='NoLACE')]

    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)

    # color the boxes cyclically: Opus / LACE / NoLACE
    for i, patch in enumerate(bplot['boxes']):
        patch.set_facecolor(colors[i%3])

    ax.set_xticklabels(compare_dict.keys(), rotation=290)

    if title is not None:
        ax.set_title(title)

    ax.legend(handles=legend_elements)
    fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)
    # plot every metric found unless a single one was requested
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())

    print("Done.")

View File

@@ -0,0 +1,109 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: render box plots comparing the two LACE training variants
# (MOC-only vs MOC + TD loss) from pre-calculated metric files.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load every per-metric result dict (results_<metric>.npy) present in
    *folder*; returns {metric: {bitrate: (num_items, 3) array}}."""
    data = dict()

    for metric in ('moc', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
def plot_data(filename, data, title=None):
    """Box plot of per-item scores: one Opus / LACE(MOC) / LACE(MOC+TD)
    box triple per bitrate, written to *filename*."""
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]

    # NOTE: requires a working LaTeX installation for text rendering
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })

    colors = ['pink', 'lightblue', 'lightgreen']
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='MOC loss only'),
                       Patch(facecolor=colors[2], label='MOC + TD loss')]

    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)

    # color the boxes cyclically per condition
    for i, patch in enumerate(bplot['boxes']):
        patch.set_facecolor(colors[i%3])

    ax.set_xticklabels(compare_dict.keys(), rotation=290)

    if title is not None:
        ax.set_title(title)

    ax.legend(handles=legend_elements)
    fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)
    # plot every metric found unless a single one was requested
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())

    print("Done.")

View File

@@ -0,0 +1,124 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: render mean/std result tables from pre-calculated metric files.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load pre-computed metric results from *folder*.

    Each metric is stored as ``results_<metric>.npy`` and contains a pickled
    dict (bitrate -> ndarray of per-item scores). Only files that actually
    exist are loaded.

    Args:
        folder (str): path to the folder with the ``results_*.npy`` files.

    Returns:
        dict: metric name -> results dict as stored in the corresponding file.
    """
    # iterate over all known metrics instead of one copy-pasted block each
    known_metrics = ['moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss']
    data = dict()
    for metric in known_metrics:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def make_table(filename, data, title=None):
    """Write a table of per-condition "mean (std)" scores per bitrate.

    The table is written as <filename>.txt, <filename>.html and
    <filename>.csv and also printed to stdout.

    Args:
        filename (str): output path without extension.
        data (dict): bitrate -> ndarray of shape (num_items, 3) with columns
            Opus, LACE, NoLACE.
        title: unused, kept for interface compatibility.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']

    for br, scores in data.items():
        # one "mean (std)" cell per condition (columns: 0 Opus, 1 LACE, 2 NoLACE)
        cells = [
            f"{float(scores[:, col].mean()):.3f} ({float(scores[:, col].std()):.2f})"
            for col in range(3)
        ]
        tbl.add_row([br] + cells)

    # export in all supported formats (replaces three copy-pasted with-blocks)
    for suffix, content in ((".txt", str(tbl)),
                            (".html", tbl.get_html_string()),
                            (".csv", tbl.get_csv_string())):
        with open(filename + suffix, "w") as f:
            f.write(content)

    print(tbl)
def make_diff_table(filename, data, title=None):
    """Write a table of mean (std) score differences vs. Opus per bitrate.

    The table is written as <filename>.txt, <filename>.html and
    <filename>.csv and also printed to stdout.

    Args:
        filename (str): output path without extension.
        data (dict): bitrate -> ndarray of shape (num_items, 3) with columns
            Opus, LACE, NoLACE.
        title: unused, kept for interface compatibility.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']

    for br, scores in data.items():
        opus = scores[:, 0]
        cells = []
        for col in (1, 2):  # 1 = LACE, 2 = NoLACE
            diff = scores[:, col] - opus
            cells.append(f"{float(diff.mean()):.3f} ({float(diff.std()):.2f})")
        tbl.add_row([br] + cells)

    # export in all supported formats (replaces three copy-pasted with-blocks)
    for suffix, content in ((".txt", str(tbl)),
                            (".html", tbl.get_html_string()),
                            (".csv", tbl.get_csv_string())):
        with open(filename + suffix, "w") as f:
            f.write(content)

    print(tbl)
if __name__ == "__main__":
    args = parser.parse_args()

    # load all available metric files and restrict to the requested metric(s)
    data = load_data(args.folder)
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # tables are written next to the input data unless --output is given
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        # bug fix: the old message said "Plotting" although tables are generated
        print(f"Generating tables for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")

View File

@@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# command line interface: input folder with results_*.npy files, an optional
# metric selection and an optional alternative output folder
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load pre-computed metric results from *folder*.

    Each metric is stored as ``results_<metric>.npy`` and contains a pickled
    dict (bitrate -> ndarray of per-item scores). Only files that actually
    exist are loaded.

    Args:
        folder (str): path to the folder with the ``results_*.npy`` files.

    Returns:
        dict: metric name -> results dict as stored in the corresponding file.
    """
    # iterate over all known metrics instead of one copy-pasted block each
    known_metrics = ['moc', 'pesq', 'warpq', 'nomad', 'laceloss']
    data = dict()
    for metric in known_metrics:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def make_table(filename, data, title=None):
    """Write a table of per-condition "mean (std)" scores per bitrate.

    The table is written as <filename>.txt, <filename>.html and
    <filename>.csv and also printed to stdout.

    Args:
        filename (str): output path without extension.
        data (dict): bitrate -> ndarray of shape (num_items, 3) with columns
            Opus, LACE, NoLACE.
        title: unused, kept for interface compatibility.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']

    for br, scores in data.items():
        # one "mean (std)" cell per condition (columns: 0 Opus, 1 LACE, 2 NoLACE)
        cells = [
            f"{float(scores[:, col].mean()):.3f} ({float(scores[:, col].std()):.2f})"
            for col in range(3)
        ]
        tbl.add_row([br] + cells)

    # export in all supported formats (replaces three copy-pasted with-blocks)
    for suffix, content in ((".txt", str(tbl)),
                            (".html", tbl.get_html_string()),
                            (".csv", tbl.get_csv_string())):
        with open(filename + suffix, "w") as f:
            f.write(content)

    print(tbl)
def make_diff_table(filename, data, title=None):
    """Write a table of mean (std) score differences vs. Opus per bitrate.

    The table is written as <filename>.txt, <filename>.html and
    <filename>.csv and also printed to stdout.

    Args:
        filename (str): output path without extension.
        data (dict): bitrate -> ndarray of shape (num_items, 3) with columns
            Opus, LACE, NoLACE.
        title: unused, kept for interface compatibility.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']

    for br, scores in data.items():
        opus = scores[:, 0]
        cells = []
        for col in (1, 2):  # 1 = LACE, 2 = NoLACE
            diff = scores[:, col] - opus
            cells.append(f"{float(diff.mean()):.3f} ({float(diff.std()):.2f})")
        tbl.add_row([br] + cells)

    # export in all supported formats (replaces three copy-pasted with-blocks)
    for suffix, content in ((".txt", str(tbl)),
                            (".html", tbl.get_html_string()),
                            (".csv", tbl.get_csv_string())):
        with open(filename + suffix, "w") as f:
            f.write(content)

    print(tbl)
if __name__ == "__main__":
    args = parser.parse_args()

    # load all available metric files and restrict to the requested metric(s)
    data = load_data(args.folder)
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]

    # tables are written next to the input data unless --output is given
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        # bug fix: the old message said "Plotting" although tables are generated
        print(f"Generating tables for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")

View File

@@ -0,0 +1,182 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
    """Compute a crude energy-based voice-activity mask for *x*.

    The signal is cut into 20 ms frames; frames whose smoothed energy falls
    more than ``stop_db`` dB below the maximum frame energy are marked
    inactive. The binary frame decision is then smoothed with a normalized
    sine window to obtain a soft per-sample mask.

    Args:
        x (np.ndarray): mono input signal.
        fs (int): sampling rate in Hz.
        stop_db (float): threshold relative to the maximum frame energy in dB.

    Returns:
        tuple: (x trimmed to a whole number of frames, per-sample soft mask).
    """
    frame_length = (fs + 49) // 50  # 20 ms frames, rounded up

    # trim to an integer number of frames
    x = x[: frame_length * (len(x) // frame_length)]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    # NOTE(review): 10**(dB/20) is applied to an energy (squared) quantity,
    # which halves the effective dB threshold -- confirm this is intended
    max_threshold = frame_energy.max() * 10 ** (stop_db/20)

    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # renamed from `filter` to avoid shadowing the builtin
    smoothing_window = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    smoothing_window = smoothing_window / smoothing_window.sum()

    mask = np.convolve(vactive, smoothing_window, mode='same')

    return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample mask to one averaged value per analysis frame."""
    num_samples = frame_size + (num_frames - 1) * hop_size

    # zero-pad or truncate so the mask covers exactly num_samples samples
    if len(mask) < num_samples:
        padding = np.zeros(num_samples - len(mask))
        mask = np.concatenate((mask, padding), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    frame_means = []
    for frame_idx in range(num_frames):
        start = frame_idx * hop_size
        frame_means.append(np.mean(mask[start : start + frame_size]))

    return np.array(frame_means)
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Compute windowed per-frame power spectra (non-negative bins only)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    num_pos_bins = window_size // 2 + 1

    # stack the windowed frames into a (num_spectra, window_size) matrix
    frame_list = []
    for k in range(num_spectra):
        start = k * hop_size
        frame_list.append(x[np.newaxis, start : start + window_size])
    frames = np.concatenate(frame_list) * win

    spectra = np.fft.fft(frames, axis=1)[:, :num_pos_bins]
    return np.abs(spectra) ** 2
def frequency_mask(num_bands, up_factor, down_factor):
    """Build a (num_bands, num_bands) spectral spreading matrix.

    Masking energy decays geometrically with *up_factor* towards lower bands
    and with *down_factor* towards higher bands.
    """
    rows_up = []
    rows_down = []
    for band in range(num_bands):
        up_row = np.zeros(num_bands)
        up_row[: band + 1] = up_factor ** np.arange(band, -1, -1)
        rows_up.append(up_row)

        down_row = np.zeros(num_bands)
        down_row[band:] = down_factor ** np.arange(num_bands - band)
        rows_down.append(down_row)

    return np.stack(rows_down) @ np.stack(rows_up)
def rect_fb(band_limits, num_bins=None):
    """Create a rectangular (boxcar) filter-bank matrix from band edges."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        num_bins = band_limits[-1]

    fb = np.zeros((num_bands, num_bins))
    # band i covers bins [band_limits[i], band_limits[i+1])
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only voice-active frames contribute to the score

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)
    # trim samples to same size, then undo the [-1, 1] scaling back to 16-bit range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15
    # additive constant acts as a spectral floor
    # NOTE(review): presumably to keep the log ratio below well-defined -- confirm
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000
    num_frames = psd_x.shape[0]
    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T
    # temporal masking: leaky integration of the mask over frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]
    # apply mask
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
    # distortion metric: squared log spectral ratio, averaged per band
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)
    if apply_vad:
        # restrict scoring to frames with voice activity in the reference
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)
    # cubic mean over active frames, compressed with a 6th root
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
    return float(err)
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both signals must be sampled at 16 kHz
    # (bug fix: the old check used max(fs1, fs2) != 16000, which accepted a
    # signal sampled below 16 kHz as long as the other one was at 16 kHz)
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # convert 16-bit PCM to float in [-1, 1)
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")

View File

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
    """Compute a crude energy-based voice-activity mask for *x*.

    The signal is cut into 20 ms frames; frames whose smoothed energy falls
    more than ``stop_db`` dB below the maximum frame energy are marked
    inactive. The binary frame decision is then smoothed with a normalized
    sine window to obtain a soft per-sample mask.

    Args:
        x (np.ndarray): mono input signal.
        fs (int): sampling rate in Hz.
        stop_db (float): threshold relative to the maximum frame energy in dB.

    Returns:
        tuple: (x trimmed to a whole number of frames, per-sample soft mask).
    """
    frame_length = (fs + 49) // 50  # 20 ms frames, rounded up

    # trim to an integer number of frames
    x = x[: frame_length * (len(x) // frame_length)]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    # NOTE(review): 10**(dB/20) is applied to an energy (squared) quantity,
    # which halves the effective dB threshold -- confirm this is intended
    max_threshold = frame_energy.max() * 10 ** (stop_db/20)

    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)

    # renamed from `filter` to avoid shadowing the builtin
    smoothing_window = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    smoothing_window = smoothing_window / smoothing_window.sum()

    mask = np.convolve(vactive, smoothing_window, mode='same')

    return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample mask to one averaged value per analysis frame."""
    num_samples = frame_size + (num_frames - 1) * hop_size

    # zero-pad or truncate so the mask covers exactly num_samples samples
    if len(mask) < num_samples:
        padding = np.zeros(num_samples - len(mask))
        mask = np.concatenate((mask, padding), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    frame_means = []
    for frame_idx in range(num_frames):
        start = frame_idx * hop_size
        frame_means.append(np.mean(mask[start : start + frame_size]))

    return np.array(frame_means)
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Compute windowed per-frame power spectra (non-negative bins only)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    num_pos_bins = window_size // 2 + 1

    # stack the windowed frames into a (num_spectra, window_size) matrix
    frame_list = []
    for k in range(num_spectra):
        start = k * hop_size
        frame_list.append(x[np.newaxis, start : start + window_size])
    frames = np.concatenate(frame_list) * win

    spectra = np.fft.fft(frames, axis=1)[:, :num_pos_bins]
    return np.abs(spectra) ** 2
def frequency_mask(num_bands, up_factor, down_factor):
    """Build a (num_bands, num_bands) spectral spreading matrix.

    Masking energy decays geometrically with *up_factor* towards lower bands
    and with *down_factor* towards higher bands.
    """
    rows_up = []
    rows_down = []
    for band in range(num_bands):
        up_row = np.zeros(num_bands)
        up_row[: band + 1] = up_factor ** np.arange(band, -1, -1)
        rows_up.append(up_row)

        down_row = np.zeros(num_bands)
        down_row[band:] = down_factor ** np.arange(num_bands - band)
        rows_down.append(down_row)

    return np.stack(rows_down) @ np.stack(rows_up)
def rect_fb(band_limits, num_bins=None):
    """Create a rectangular (boxcar) filter-bank matrix from band edges."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        num_bins = band_limits[-1]

    fb = np.zeros((num_bands, num_bins))
    # band i covers bins [band_limits[i], band_limits[i+1])
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
def _compare(x, y, apply_vad=False, factor=1):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only voice-active frames contribute to the score
        factor (int): resolution multiplier applied to window, hop and band sizes

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
    window_size = factor * 160
    hop_size = factor * 40
    num_bins = window_size // 2 + 1
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=num_bins)
    # trim samples to same size, then undo the [-1, 1] scaling back to 16-bit range
    num_samples = min(len(x), len(y))
    x = x[:num_samples].copy() * 2**15
    y = y[:num_samples].copy() * 2**15
    # additive constant acts as a spectral floor
    # NOTE(review): presumably to keep the log ratio below well-defined -- confirm
    psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
    psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000
    num_frames = psd_x.shape[0]
    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T
    # temporal masking: leaky integration, decay compensated for the frame rate
    for i in range(1, num_frames):
        mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]
    # apply mask
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
    # distortion metric: squared log spectral ratio, averaged per band
    re = masked_psd_y / masked_psd_x
    #im = re - np.log(re) - 1
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb ** 1, axis=1)
    if apply_vad:
        # restrict scoring to frames with voice activity in the reference
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)
    # cubic mean over active frames, compressed with a 6th root
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
    return float(err)
def compare(x, y, apply_vad=False):
    """Public entry point: MOC distance at the base resolution (factor 1)."""
    score = _compare(x, y, apply_vad=apply_vad, factor=1)
    # norm of a single-element vector; kept for compatibility with
    # multi-resolution variants combining several factors
    return np.linalg.norm([score], ord=2)
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both signals must be sampled at 16 kHz
    # (bug fix: the old check used max(fs1, fs2) != 16000, which accepted a
    # signal sampled below 16 kHz as long as the other one was at 16 kHz)
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # convert 16-bit PCM to float in [-1, 1)
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")

View File

@@ -0,0 +1,98 @@
#!/bin/bash

# Batch-process a folder of wav files: encode/decode each item with opus_demo
# at a set of bitrates and enhance the decoded output with the LACE and
# NoLACE models. Required environment variables: PYTHON, TESTMODEL, OPUSDEMO,
# LACE, NOLACE; optional: BITRATES.

if [ ! -f "$PYTHON" ]
then
    echo "PYTHON variable does not link to a file. Please point it to your python executable."
    exit 1
fi

if [ ! -f "$TESTMODEL" ]
then
    echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
    exit 1
fi

if [ ! -f "$OPUSDEMO" ]
then
    echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
    exit 1
fi

if [ ! -f "$LACE" ]
then
    echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
    exit 1
fi

if [ ! -f "$NOLACE" ]
then
    # bug fix: this message previously referred to the LACE variable
    echo "NOLACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
    exit 1
fi

case $# in
    2) INPUT=$1; OUTPUT=$2;;
    *) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
esac

# refuse to overwrite an existing output folder
if [ -d "$OUTPUT" ]
then
    echo "output folder $OUTPUT exists, aborting..."
    exit 1
fi

mkdir -p "$OUTPUT"

if [ "$BITRATES" == "" ]
then
    BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
    echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
fi

# record which model checkpoints were used for this run
echo "LACE=${LACE}" > ${OUTPUT}/info.txt
echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt

ITEMFILE=${OUTPUT}/items.txt
BITRATEFILE=${OUTPUT}/bitrates.txt

FPROCESSING=${OUTPUT}/processing
FCLEAN=${OUTPUT}/clean
FOPUS=${OUTPUT}/opus
FLACE=${OUTPUT}/lace
FNOLACE=${OUTPUT}/nolace

mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE

echo "${BITRATES[@]}" > $BITRATEFILE

for fn in $(find $INPUT -type f -name "*.wav")
do
    # give every item a unique id so output names cannot collide
    UUID=$(uuid)
    echo "$UUID $fn" >> $ITEMFILE
    PIDS=( )
    for br in ${BITRATES[@]}
    do
        # run opus
        pfolder=${FPROCESSING}/${UUID}_${br}
        mkdir -p $pfolder
        sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16
        (cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)

        # copy clean and opus
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav

        # run LACE in the background; PIDs are collected for the final wait
        $PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav &
        PIDS+=( "$!" )

        # run NoLACE
        $PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav &
        PIDS+=( "$!" )
    done
    # wait for all enhancement jobs of this item before starting the next one
    for pid in ${PIDS[@]}
    do
        wait $pid
    done
done

View File

@@ -0,0 +1,138 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import tempfile
import shutil
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.io import wavfile
import numpy as np
from nomad_audio.nomad import Nomad
# command line interface: folder produced by the dataset processing script,
# plus NOMAD mode and compute-device selection
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
def get_bitrates(folder):
    """Read the whitespace-separated bitrate list from <folder>/bitrates.txt."""
    with open(os.path.join(folder, 'bitrates.txt')) as f:
        content = f.read()
    return [int(token) for token in content.rstrip('\n').split()]
def get_itemlist(folder):
    """Return the item ids (first token of each line) from <folder>/items.txt."""
    with open(os.path.join(folder, 'items.txt')) as f:
        lines = f.readlines()
    return [line.split()[0] for line in lines]
def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
    """Score all files in deg_folder with NOMAD against ref_folder.

    Args:
        ref_folder (str): folder with reference wav files.
        deg_folder (str): folder with degraded wav files (same stem names).
        full_reference (bool): if True, compute embedding distances pairwise
            per item instead of NOMAD's default (non-matching-reference) mode.
        ref_embeddings: cached reference embeddings from a previous call;
            recomputed from ref_folder when None.
        device: device string forwarded to Nomad (e.g. 'cpu', 'cuda').

    Returns:
        tuple: (dict mapping file stem -> score, reference embeddings or None).
    """
    model = Nomad(device=device)
    if not full_reference:
        # NOMAD's built-in prediction; returns one score per degraded file
        results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
        return results, None
    else:
        if ref_embeddings is None:
            print(f"Computing reference embeddings from {ref_folder}")
            # sorted listing keeps reference and degraded rows aligned
            ref_data = pd.DataFrame(sorted(os.listdir(ref_folder)))
            ref_data.columns = ['filename']
            ref_data['filename'] = [os.path.join(ref_folder, x) for x in ref_data['filename']]
            ref_embeddings = model.get_embeddings_csv(model.model, ref_data).set_index('filename')
        print(f"Computing degraded embeddings from {deg_folder}")
        deg_data = pd.DataFrame(sorted(os.listdir(deg_folder)))
        deg_data.columns = ['filename']
        deg_data['filename'] = [os.path.join(deg_folder, x) for x in deg_data['filename']]
        deg_embeddings = model.get_embeddings_csv(model.model, deg_data).set_index('filename')
        # only the diagonal (matching pairs) is needed; computing the full
        # distance matrix is wasteful but simple
        dist = np.diag(cdist(ref_embeddings, deg_embeddings)) # wasteful
        # key results by file stem (without folder and extension)
        test_files = [x.split('/')[-1].split('.')[0] for x in deg_embeddings.index]
        results = dict(zip(test_files, dist))
        return results, ref_embeddings
def nomad_process_all(folder, full_reference=False, device=None):
    """Run NOMAD on all conditions (opus, lace, nolace) of a processed dataset.

    Copies the wav files into a flat temporary layout expected by NOMAD,
    scores each condition against the clean references and collects the
    scores per bitrate.

    Args:
        folder (str): dataset folder with clean/opus/lace/nolace subfolders,
            items.txt and bitrates.txt.
        full_reference (bool): forwarded to nomad_wrapper.
        device: forwarded to nomad_wrapper.

    Returns:
        dict: bitrate -> ndarray of shape (num_items, 3) with columns
            Opus, LACE, NoLACE.
    """
    bitrates = get_bitrates(folder)
    items = get_itemlist(folder)
    with tempfile.TemporaryDirectory() as dir:
        cleandir = os.path.join(dir, 'clean')
        opusdir = os.path.join(dir, 'opus')
        lacedir = os.path.join(dir, 'lace')
        nolacedir = os.path.join(dir, 'nolace')
        # prepare files
        for d in [cleandir, opusdir, lacedir, nolacedir]: os.makedirs(d)
        for br in bitrates:
            for item in items:
                for cond in ['clean', 'opus', 'lace', 'nolace']:
                    shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(dir, cond, f"{item}_{br}.wav"))
        # reference embeddings are computed once and reused for all conditions
        nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None)
        nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
        nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
    # re-organize the flat score dicts into one (num_items, 3) array per bitrate
    results = dict()
    for br in bitrates:
        results[br] = np.zeros((len(items), 3))
        for i, item in enumerate(items):
            key = f"{item}_{br}"
            results[br][i, 0] = nomad_opus[key]
            results[br][i, 1] = nomad_lace[key]
            results[br][i, 2] = nomad_nolace[key]
    return results
if __name__ == "__main__":
    args = parser.parse_args()

    # removed unused locals (items/bitrates were loaded here but recomputed
    # inside nomad_process_all anyway)
    results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)

    # results are stored next to the input data
    np.save(os.path.join(args.folder, 'results_nomad.npy'), results)

    print("Done.")

View File

@@ -0,0 +1,196 @@
import os
import argparse
import yaml
import subprocess
import numpy as np
from moc2 import compare as moc
# debug flag (currently unused in the visible code path)
DEBUG=False
# command line interface for the SILK enhancement regression test
parser = argparse.ArgumentParser()
parser.add_argument('inputdir', type=str, help='Input folder with test items')
parser.add_argument('outputdir', type=str, help='Output folder')
parser.add_argument('bitrate', type=int, help='bitrate to test')
parser.add_argument('--reference_opus_demo', type=str, default='./opus_demo', help='reference opus_demo binary for generating bitstreams and reference output')
parser.add_argument('--encoder_options', type=str, default="", help='encoder options (e.g. -complexity 5)')
parser.add_argument('--test_opus_demo', type=str, default='./opus_demo', help='opus_demo binary under test')
parser.add_argument('--test_opus_demo_options', type=str, default='-dec_complexity 7', help='options for test opus_demo (e.g. "-dec_complexity 7")')
parser.add_argument('--verbose', type=int, default=0, help='verbosity level: 0 for quiet (default), 1 for reporting individual test results, 2 for reporting per-item scores in failed tests')
def run_opus_encoder(opus_demo_path, input_pcm_path, bitstream_path, application, fs, num_channels, bitrate, options=None, verbose=False):
    """Run opus_demo in encoder mode.

    Args:
        opus_demo_path (str): path to the opus_demo binary.
        input_pcm_path (str): path to the 16-bit raw PCM input file.
        bitstream_path (str): path for the generated bitstream.
        application (str): opus_demo application mode (e.g. "voip").
        fs (int): sampling rate in Hz.
        num_channels (int): number of audio channels.
        bitrate (int): target bitrate in bits per second.
        options (list | None): extra command line options inserted before the
            file arguments (fixed mutable-default-argument; None means no
            extra options, same as the old [] default).
        verbose (bool): if True, print the command and show its output.

    Returns:
        int: 0 on success, 1 if the binary could not be executed.
    """
    call_args = [
        opus_demo_path,
        "-e",
        application,
        str(fs),
        str(num_channels),
        str(bitrate),
        "-bandwidth",
        "WB"
    ]
    call_args += options or []
    call_args += [
        input_pcm_path,
        bitstream_path
    ]
    try:
        if verbose:
            print(f"running {call_args}...")
            subprocess.run(call_args)
        else:
            subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except Exception:  # narrowed from a bare except; e.g. binary missing
        return 1
    return 0
def run_opus_decoder(opus_demo_path, bitstream_path, output_pcm_path, fs, num_channels, options=None, verbose=False):
    """Run opus_demo in decoder mode.

    Args:
        opus_demo_path (str): path to the opus_demo binary.
        bitstream_path (str): path to the bitstream to decode.
        output_pcm_path (str): path for the decoded 16-bit raw PCM output.
        fs (int): sampling rate in Hz.
        num_channels (int): number of audio channels.
        options (list | None): extra command line options inserted before the
            file arguments (fixed mutable-default-argument; None means no
            extra options, same as the old [] default).
        verbose (bool): if True, print the command and show its output.

    Returns:
        int: 0 on success, 1 if the binary could not be executed.
    """
    call_args = [
        opus_demo_path,
        "-d",
        str(fs),
        str(num_channels)
    ]
    call_args += options or []
    call_args += [
        bitstream_path,
        output_pcm_path
    ]
    try:
        if verbose:
            print(f"running {call_args}...")
            subprocess.run(call_args)
        else:
            subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except Exception:  # narrowed from a bare except; e.g. binary missing
        return 1
    return 0
def compute_moc_score(reference_pcm, test_pcm, delay=91):
    """Compute the MOC distance between two 16-bit raw PCM files.

    The first *delay* samples of the test signal are dropped to compensate
    for the codec delay before comparison.
    """
    scale = 2 ** 15
    reference = np.fromfile(reference_pcm, dtype=np.int16).astype(np.float32) / scale
    test = np.fromfile(test_pcm, dtype=np.int16).astype(np.float32) / scale
    return moc(reference, test[delay:])
def sox(*call_args):
    """Run the sox command line tool with the given arguments.

    Returns:
        int: 0 on success, 1 if sox could not be executed.
    """
    try:
        subprocess.run(["sox"] + list(call_args))
        return 0
    except Exception:  # narrowed from a bare except; e.g. sox not installed
        return 1
def process_clip_factory(ref_opus_demo, test_opus_demo, enc_options, test_options):
    """Create a closure that scores one clip with both opus_demo binaries.

    Args:
        ref_opus_demo (str): reference opus_demo binary (encoder + reference decoder).
        test_opus_demo (str): opus_demo binary under test (decoder only).
        enc_options (list): extra encoder command line options.
        test_options (list): extra decoder options for the binary under test.

    Returns:
        callable: process_clip(clip_path, processdir, bitrate) ->
            (reference MOC score, test MOC score), lower is better.
    """
    def process_clip(clip_path, processdir, bitrate):
        # derive paths
        clipname = os.path.splitext(os.path.split(clip_path)[1])[0]
        pcm_path = os.path.join(processdir, clipname + ".raw")
        bitstream_path = os.path.join(processdir, clipname + ".bin")
        ref_path = os.path.join(processdir, clipname + "_ref.raw")
        test_path = os.path.join(processdir, clipname + "_test.raw")
        # run sox to convert the clip to raw PCM
        sox(clip_path, pcm_path)
        # run encoder (shared bitstream for both decoders)
        run_opus_encoder(ref_opus_demo, pcm_path, bitstream_path, "voip", 16000, 1, bitrate, enc_options)
        # run both decoders on the same bitstream
        run_opus_decoder(ref_opus_demo, bitstream_path, ref_path, 16000, 1)
        run_opus_decoder(test_opus_demo, bitstream_path, test_path, 16000, 1, options=test_options)
        # score both decoded signals against the clean input
        d_ref = compute_moc_score(pcm_path, ref_path)
        d_test = compute_moc_score(pcm_path, test_path)
        return d_ref, d_test
    return process_clip
def main(inputdir, outputdir, bitrate, reference_opus_demo, test_opus_demo, enc_option_string, test_option_string, verbose):
    """Run the regression test over all clips listed in <inputdir>/clips.yml.

    For every language group the clips are decoded with the reference and the
    test binary; a group passes unless the relative MOC degradation exceeds
    the thresholds below. Per-group results are saved to the output folder.
    """
    # load clips list
    with open(os.path.join(inputdir, 'clips.yml'), "r") as f:
        clips = yaml.safe_load(f)
    # parse test options
    enc_options = enc_option_string.split()
    test_options = test_option_string.split()
    process_clip = process_clip_factory(reference_opus_demo, test_opus_demo, enc_options, test_options)
    os.makedirs(outputdir, exist_ok=True)
    processdir = os.path.join(outputdir, 'process')
    os.makedirs(processdir, exist_ok=True)
    num_passed = 0
    results = dict()
    # track worst-case clip and worst-case language over the whole run
    min_rel_diff = 1000
    min_mean = 1000
    worst_clip = None
    worst_lang = None
    for lang, lang_clips in clips.items():
        if verbose > 0: print(f"processing language {lang}...")
        # per-clip scores: column 0 = reference decoder, column 1 = test decoder
        results[lang] = np.zeros((len(lang_clips), 2))
        for i, clip in enumerate(lang_clips):
            clip_path = os.path.join(inputdir, clip)
            d_ref, d_test = process_clip(clip_path, processdir, bitrate)
            results[lang][i, 0] = d_ref
            results[lang][i, 1] = d_test
        # relative improvement of the test decoder over the reference,
        # computed on alpha-compressed scores (positive = test is better)
        alpha = 0.5
        rel_diff = ((results[lang][:, 0] ** alpha - results[lang][:, 1] ** alpha) /(results[lang][:, 0] ** alpha))
        min_idx = np.argmin(rel_diff).item()
        if rel_diff[min_idx] < min_rel_diff:
            min_rel_diff = rel_diff[min_idx]
            worst_clip = lang_clips[min_idx]
        if np.mean(rel_diff) < min_mean:
            min_mean = np.mean(rel_diff).item()
            worst_lang = lang
        # fail thresholds: any clip worse than -10% or group mean worse than -2.5%
        if np.min(rel_diff) < -0.1 or np.mean(rel_diff) < -0.025:
            if verbose > 0: print(f"FAIL ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            if verbose > 1:
                for i, c in enumerate(lang_clips):
                    print(f"    {c:50s} {results[lang][i]} {rel_diff[i]}")
        else:
            if verbose > 0: print(f"PASS ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            num_passed += 1
    print(f"{num_passed}/{len(clips)} tests passed!")
    print(f"worst case occured at clip {worst_clip} with relative difference of {min_rel_diff}")
    print(f"worst mean relative difference was {min_mean} for test {worst_lang}")
    np.save(os.path.join(outputdir, f'results_' + "_".join(test_options) + f"_{bitrate}.npy"), results, allow_pickle=True)
if __name__ == "__main__":
    # parse the command line and forward everything to main()
    args = parser.parse_args()
    main(args.inputdir,
         args.outputdir,
         args.bitrate,
         args.reference_opus_demo,
         args.test_opus_demo,
         args.encoder_options,
         args.test_opus_demo_options,
         args.verbose)

View File

@@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# module-level registry of open endoscopy streams
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
# output folder for all streams; overridden by init()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
    """Evaluate the reset, update and new gates of a single-layer GRU.

    Args:
        gru (torch.nn.GRU): single-layer GRU whose l0 parameters are read.
        input (torch.Tensor): input vector for one time step.
        state (torch.Tensor): previous hidden state.

    Returns:
        dict: gate activations under 'reset_gate', 'update_gate', 'new_gate'.
    """
    h = gru.hidden_size

    # projections of input and previous state onto all three gates at once
    x_proj = torch.matmul(gru.weight_ih_l0, input.squeeze())
    h_proj = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def gate_parts(idx):
        # biased input and recurrent contributions for gate number idx
        sl = slice(idx * h, (idx + 1) * h)
        return x_proj[sl] + gru.bias_ih_l0[sl], h_proj[sl] + gru.bias_hh_l0[sl]

    reset_x, reset_h = gate_parts(0)
    reset_gate = torch.sigmoid(reset_x + reset_h)

    update_x, update_h = gate_parts(1)
    update_gate = torch.sigmoid(update_x + update_h)

    # the recurrent part of the candidate state is scaled by the reset gate
    new_x, new_h = gate_parts(2)
    new_gate = torch.tanh(new_x + reset_gate * new_h)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """
    global _folder
    _folder = folder

    if os.path.exists(folder):
        # existing folders are reused, but stale files from a previous run may remain
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key """
    global _state

    # torch tensors are converted to numpy before serialization
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    entry = _state.get(key)
    if entry is None:
        # first write for this key: open the binary sink and record metadata
        entry = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }
        _state[key] = entry
        # side-car yml file describes how to decode the binary stream later
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # subsequent writes must match the metadata recorded on the first one
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {entry['fs']} vs. {fs}")
        if entry['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {entry['dtype']} vs. {str(data.dtype)}")
        if entry['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {entry['dim']} vs. {tuple(data.shape)}")

    entry['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up: closes all open data files and resets the module state

    The folder argument is unused; it is kept for backward compatibility.
    """
    for entry in _state.values():
        entry['fid'].close()
    # drop the stale entries so a later init()/write_data() cycle starts fresh;
    # previously the closed handles were kept in _state, so any further
    # write_data() call for an existing key would fail on a closed file
    _state.clear()
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """
    return_dict = dict()

    # every .yml side-car file corresponds to one recorded data stream
    for name in os.listdir(folder):
        if not name.endswith('.yml'):
            continue
        key = name[:-4]

        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=value['dtype'])

        # leading -1 restores the per-call stacking dimension
        value['data'] = raw.reshape((-1,) + value['dim'])
        return_dict[key] = value

    return return_dict
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)"""
    # total number of entries
    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    # start from the column count matching the target aspect ratio and
    # decrease until it divides the pixel count evenly
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns != 0:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
def get_type_and_shape(shape):
    """ decides whether data should be rendered as a line plot ('plot') or as
    an image ('image') and returns the display shape to use """
    # scalar data may arrive with an empty (0-dimensional) shape
    if len(shape) == 0:
        shape = (1,)

    # total number of values per frame
    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    # one value per frame is displayed as a sliding line plot
    if pixel_count == 1:
        return 'plot', (1, )

    # keep an existing 2d shape unless both extents equal the pixel count
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Renders the data dict returned by read_data() to an mp4 animation.

    One subplot per key: scalar streams are drawn as sliding line plots,
    multi-dimensional streams as per-frame images.
    """
    # determine plot setup
    num_keys = len(data.keys())
    # guard against num_rows == 0 (e.g. a single key), which previously
    # caused a division by zero when computing num_cols
    num_rows = max(1, int((num_keys * 3/4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows
    # squeeze=False keeps axs 2-dimensional even for a single row or column,
    # so the axs[row, col] indexing below works for any subplot layout
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())
    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        # ratio between this stream's rate and the fastest stream's rate
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the animation range so the sliding signal window stays in bounds
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples
    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but the plots below index
            # with the raw index; presumably downsampled streams should use
            # feature_index -- confirm intended behavior
            feature_index = int(round(index * display[key]['down_factor']))
            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if not filename.endswith('.mp4'):
        filename += '.mp4'
    ani.save(filename)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,25 @@
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
    """Save an animation of a spectrogram with a moving playback cursor.

    savepath    : output video path (format handled by matplotlib/ffmpeg)
    spec        : 2d array (height x width) displayed via imshow
    duration_ms : duration in milliseconds; sets figure width and frame count
    vmin, vmax  : color scale limits
    """
    fig, axs = plt.subplots()
    axs.set_axis_off()
    # 5 inches of figure width per second of audio
    fig.set_size_inches((duration_ms / 1000 * 5, 5))
    frames = []
    frame_duration=20
    # number of frame_duration-sized slots, rounded up
    num_frames = int(duration_ms / frame_duration + .99)
    spec_height, spec_width = spec.shape
    for i in range(num_frames):
        # cursor position; NOTE(review): (i - 1) / (num_frames - 3) divides by
        # zero for num_frames == 3 and maps the first/last interior frames
        # slightly outside the spectrogram -- presumably intentional so the
        # cursor rests at the borders; confirm against callers
        xpos = (i - 1) / (num_frames - 3) * (spec_width - 1)
        new_frame = axs.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        if i in {0, num_frames - 1}:
            # first and last frame show the spectrogram without a cursor line
            frames.append([new_frame])
        else:
            line = axs.plot([xpos, xpos], [0, spec_height-1], color='white', alpha=0.8)[0]
            frames.append([new_frame, line])
    ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)
    ani.save(savepath, dpi=720)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,96 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import torch
from scipy.io import wavfile
from models import model_dict
from utils.silk_features import load_inference_data
from utils import endoscopy
# toggle for running with hard-coded arguments instead of the command line
debug = False
if debug:
    # the 'debug' key is required here: args.debug is read further down, and
    # the previous dummy object lacked it, raising AttributeError whenever
    # this branch was active
    args = type('dummy', (object,),
    {
        'input' : 'testitems/all_0_orig.se',
        'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
        'output' : 'out.wav',
        'debug' : False
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str, help='path to folder with features and signals')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')
    args = parser.parse_args()
torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint
output_file = args.output
# normalize the output extension
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model
# older checkpoints may predate the model-name field; fall back to the
# default pitch post filter in that case
if not 'name' in checkpoint['setup']['model']:
    print(f'warning: did not find model name entry in setup, using pitchpostfilter per default')
    model_name = 'pitchpostfilter'
else:
    model_name = checkpoint['setup']['model']['name']

# re-instantiate the model from the setup stored in the checkpoint
model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])

# generate model input
setup = checkpoint['setup']
signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])

# endoscopy collects intermediate model activations when debugging is enabled
if args.debug:
    endoscopy.init()

output = model.process(signal, features, periods, numbits, debug=args.debug)

# NOTE(review): output is written at 16 kHz; confirm this matches the model
wavfile.write(output_file, 16000, output.cpu().numpy())

if args.debug:
    endoscopy.close()

View File

@@ -0,0 +1,103 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import argparse
import torch
from scipy.io import wavfile
from time import time
from models import model_dict
from utils.lpcnet_features import load_lpcnet_features
from utils import endoscopy
# toggle for running with hard-coded arguments instead of the command line
debug = False
if debug:
    # the 'debug' key is required here: args.debug is read further down, and
    # the previous dummy object lacked it, raising AttributeError whenever
    # this branch was active
    args = type('dummy', (object,),
    {
        'input' : 'testitems/all_0_orig.se',
        'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
        'output' : 'out.wav',
        'debug' : False
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str, help='path to input features')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')
    args = parser.parse_args()
torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint
output_file = args.output
# normalize the output extension
if not output_file.endswith('.wav'):
    output_file += '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# check model
# older checkpoints may predate the model-name field; fall back to the
# default pitch post filter in that case
if not 'name' in checkpoint['setup']['model']:
    print(f'warning: did not find model name entry in setup, using pitchpostfilter per default')
    model_name = 'pitchpostfilter'
else:
    model_name = checkpoint['setup']['model']['name']

# re-instantiate the model from the setup stored in the checkpoint
model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])

# generate model input
setup = checkpoint['setup']

testdata = load_lpcnet_features(input_folder)

features = testdata['features']
periods = testdata['periods']

if args.debug:
    endoscopy.init()

# time the vocoder inference
start = time()
output = model.process(features, periods, debug=args.debug)
elapsed = time() - start
print(f"[timing] inference took {elapsed * 1000} ms")

# NOTE(review): output is written at 16 kHz; confirm this matches the model
wavfile.write(output_file, 16000, output.cpu().numpy())

if args.debug:
    endoscopy.close()

View File

@@ -0,0 +1,307 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
# fixed seed for reproducible training runs
seed=1888

import os
import argparse
import sys
import random
random.seed(seed)

import yaml

# gitpython is optional; only used to record repo state in the setup dump
try:
    import git
    has_git = True
except:
    has_git = False

import torch
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR

import numpy as np
np.random.seed(seed)

from scipy.io import wavfile
import pesq

from data import SilkEnhancementSet
from models import model_dict
from engine.engine import train_one_epoch, evaluate

from utils.silk_features import load_inference_data
from utils.misc import count_parameters, count_nonzero_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss

parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

torch.set_num_threads(4)

# the setup yaml fully describes model, data and training hyper-parameters
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'

# check model
if not 'name' in setup['model']:
    print(f'warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder; ask for confirmation before reusing an existing one
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')

    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup, including a diff file if the working tree is dirty
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")
            with open(os.path.join(args.output, 'repo.diff'), "w") as f:
                f.write(repo.git.execute(["git", "diff"]))

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

ref = None
# optional inference test data; the clean reference enables PESQ scoring
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)

    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except:
        pass
else:
    inference_test = False

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# optionally warm-start from an earlier checkpoint
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if type(args.device) == type(None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler: hyperbolic decay per step
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

# loss term weights from the setup file
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

# total weight used to normalize the combined loss
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss: ~0 for perfectly correlated signals,
    1 for orthogonal ones and 2 for anti-correlated ones."""
    reduce_dims = list(range(1, len(y_true.shape)))
    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    # small epsilon keeps the denominator finite for all-zero signals
    denominator = torch.sqrt(torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims) + 1e-9)
    return torch.mean(1 - numerator / denominator)
def td_l2_norm(y_true, y_pred):
    """Time-domain squared error normalized by the RMS of the prediction."""
    reduce_dims = list(range(1, len(y_true.shape)))
    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5
    # epsilon avoids division by zero for silent predictions
    return torch.mean(mse / (rms + 1e-6))
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally normalized by the mean magnitude of the
    prediction raised to pow (pow=0 disables the normalization)."""
    reduce_dims = list(range(1, len(y_true.shape)))
    l1 = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow
    return torch.mean(l1 / scale)
def criterion(x, y):
    """ combined training loss: weighted sum of time-domain, spectral and
    correlation terms, normalized by the total loss weight w_sum """
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# model checkpoint; state_dict and loss are refreshed every epoch
checkpoint = {
    'setup' : setup,
    'state_dict' : model.state_dict(),
    'loss' : -1
}

# redirect stdout to a log file inside the output folder unless disabled
if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

# baseline quality of the coded test item before enhancement
if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        # keep a separate copy of the best-validation-loss checkpoint
        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    # per-epoch checkpoint plus a rolling 'last' checkpoint
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")

print('Done')

View File

@@ -0,0 +1,287 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import sys

import yaml

# gitpython is optional; only used to record repo state in the setup dump
try:
    import git
    has_git = True
except:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR

from scipy.io import wavfile

import pesq

from data import LPCNetVocodingDataset
from models import model_dict
from engine.vocoder_engine import train_one_epoch, evaluate

from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss

parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

torch.set_num_threads(4)

# the setup yaml fully describes model, data and training hyper-parameters
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file='out.txt'

# check model
if not 'name' in setup['model']:
    print(f'warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']
# prepare output folder; ask for confirmation before reusing an existing one
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')

    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # os._exit requires an exit status; the previous bare os._exit() call
        # raised a TypeError instead of terminating the process
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup so runs can be traced back to a commit
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# NOTE(review): ref stays None in this script, so all PESQ branches below are dead
ref = None

# prepare inference test if wanted
inference_test = False
if type(args.test_features) != type(None):
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

# optionally warm-start from an earlier checkpoint
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if type(args.device) == type(None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler: hyperbolic decay per step
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))

# loss term weights from the setup file
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

# total weight used to normalize the combined loss
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss: ~0 for perfectly correlated signals,
    1 for orthogonal ones and 2 for anti-correlated ones."""
    reduce_dims = list(range(1, len(y_true.shape)))
    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    # small epsilon keeps the denominator finite for all-zero signals
    denominator = torch.sqrt(torch.sum(y_true ** 2, dim=reduce_dims) * torch.sum(y_pred ** 2, dim=reduce_dims) + 1e-9)
    return torch.mean(1 - numerator / denominator)
def td_l2_norm(y_true, y_pred):
    """Time-domain squared error normalized by the RMS of the prediction."""
    reduce_dims = list(range(1, len(y_true.shape)))
    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5
    # epsilon avoids division by zero for silent predictions
    return torch.mean(mse / (rms + 1e-6))
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally normalized by the mean magnitude of the
    prediction raised to pow (pow=0 disables the normalization)."""
    reduce_dims = list(range(1, len(y_true.shape)))
    l1 = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow
    return torch.mean(l1 / scale)
def criterion(x, y):
    """ combined training loss: weighted sum of time-domain, spectral and
    correlation terms, normalized by the total loss weight w_sum """
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum
# model checkpoint; state_dict and loss are refreshed every epoch
checkpoint = {
    'setup' : setup,
    'state_dict' : model.state_dict(),
    'loss' : -1
}

# redirect stdout to a log file inside the output folder unless disabled
if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

# NOTE(review): ref is always None in this script, so this branch is dead code
if ref is not None:
    pass

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        # keep a separate copy of the best-validation-loss checkpoint
        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        # NOTE(review): unreachable while ref stays None (see above)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    # per-epoch checkpoint plus a rolling 'last' checkpoint
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))

    print()

print('Done')

View File

@@ -0,0 +1,71 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jean-Marc Valin */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
# x is (batch, nb_in_channels, nb_frames*frame_size)
# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """Applies per-frame adaptive convolution kernels to x via FFT-based
    convolution with overlap-add cross-fading between adjacent frames.

    x           : (batch, nb_in_channels, nb_frames * frame_size) input signal
    kernels     : (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
                  per-frame convolution kernels
    half_window : rising half of the cross-fade window; its length defines the
                  overlap between adjacent frames
    fft_size    : FFT length; must satisfy fft_size >= 2 * frame_size + overlap
                  for the zero-padding below to be non-negative
    returns     : (batch, nb_out_channels, nb_frames * frame_size)
    """
    device=x.device
    overlap_size=half_window.size(-1)
    nb_frames=kernels.size(3)
    nb_batches=kernels.size(0)
    nb_out_channels=kernels.size(1)
    nb_in_channels=kernels.size(2)
    kernel_size = kernels.size(-1)

    # split x into frames: (batch, 1, in_channels, frames, frame_size)
    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)

    # build window: [zeros, rising window, ones, falling window, zeros]
    window = torch.cat(
        [
            torch.zeros(frame_size, device=device),
            half_window,
            torch.ones(frame_size - overlap_size, device=device),
            1 - half_window,
            torch.zeros(fft_size - 2 * frame_size - overlap_size,device=device)
        ])
    # surround each frame with its predecessor and the head of its successor so
    # the circular FFT convolution acts like a linear one over the valid span
    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
    x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)
    x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
    # kernels are flipped so the spectral product implements convolution with
    # the original tap order
    k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)

    # compute convolution
    X = torch.fft.rfft(x_padded, dim=-1)
    K = torch.fft.rfft(k_padded, dim=-1)

    out = torch.fft.irfft(X * K, dim=-1)

    # combine in channels
    out = torch.sum(out, dim=2)

    # apply the cross-fading
    out = window.reshape(1, 1, 1, -1)*out
    # overlap-add: each frame's faded-in center plus the previous frame's
    # faded-out tail
    crossfaded = out[:,:,:,frame_size:2*frame_size] + torch.cat([torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device), out[:, :, :-1, 2*frame_size:3*frame_size]], dim=-2)

    return crossfaded.reshape(nb_batches, nb_out_channels, -1)

View File

@@ -0,0 +1,8 @@
def _conv1d_flop_count(layer, rate):
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
def _dense_flop_count(layer, rate):
return 2 * ((layer.in_features + 1) * layer.out_features * rate )

View File

@@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# Module-level registry of open recordings.
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
# output folder for .bin/.yml pairs; overridden by init()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
    """Compute the gate activations of a single-layer GRU for one time step.

    Returns a dict with keys 'reset_gate', 'update_gate' and 'new_gate'.
    """
    h = gru.hidden_size

    direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
    recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def _chunk(vec, gate_index):
        # the stacked gate layout of torch GRU weights is [reset | update | new]
        return vec[gate_index * h : (gate_index + 1) * h]

    gates = {}
    # reset and update gates share the same sigmoid-affine form
    for gate_name, idx in (('reset_gate', 0), ('update_gate', 1)):
        gates[gate_name] = torch.sigmoid(_chunk(direct, idx) + _chunk(gru.bias_ih_l0, idx) + _chunk(recurrent, idx) + _chunk(gru.bias_hh_l0, idx))

    # candidate state: the recurrent contribution is modulated by the reset gate
    gates['new_gate'] = torch.tanh(_chunk(direct, 2) + _chunk(gru.bias_ih_l0, 2) + gates['reset_gate'] * (_chunk(recurrent, 2) + _chunk(gru.bias_hh_l0, 2)))

    return gates
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """
    global _folder
    _folder = folder

    # create the folder on first use; warn (but proceed) if it already exists,
    # since stale files from a previous run may mix with new recordings
    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On the first call for a given key this opens ``<_folder>/<key>.bin`` and
    records the metadata (sample rate, shape, dtype) both in the module state
    and in a ``<key>.yml`` sidecar file so read_data can reconstruct the arrays.
    Subsequent calls must pass data with identical fs, dtype and shape.
    """
    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    if not key in _state:
        # first write for this key: open the binary file and store metadata
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # later writes must be consistent with the recorded metadata
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    # append raw bytes; arrays are recovered later via the yaml metadata
    _state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up

    Closes all open binary files and clears the module state so that a
    subsequent init()/write_data() cycle starts fresh. Previously the entries
    were kept in _state after closing, so any later write_data call for the
    same key would raise on a closed file handle.

    Parameters:
    -----------
    folder : str, optional
        kept for interface compatibility; not used
    """
    global _state
    for key in list(_state.keys()):
        _state[key]['fid'].close()
        # drop the entry so write_data re-opens the file if recording resumes
        del _state[key]
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """
    result = dict()

    # every .yml sidecar file corresponds to one recorded key
    for entry in os.listdir(folder):
        if not entry.endswith('.yml'):
            continue
        key = entry[:-4]

        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            meta = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=meta['dtype'])

        # leading axis counts the number of write_data calls for this key
        meta['data'] = raw.reshape((-1,) + meta['dim'])
        result[key] = meta

    return result
def get_best_reshape(shape, target_ratio=1):
    """ calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
    # total number of elements
    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    # start from the column count matching the target aspect ratio and walk
    # down until the element count divides evenly
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
def get_type_and_shape(shape):
    """Choose a display mode ('plot' or 'image') and a 2d shape for data of the given shape."""
    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim
    else:
        pixel_count = shape[0]

    # scalars are animated as a line plot over time
    if pixel_count == 1:
        return 'plot', (1, )

    # keep an existing 2d shape unless both dims equal the element count
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render all recorded keys (as returned by read_data) side by side as an mp4 animation.

    'plot' keys show a sliding signal window around the current sample index;
    'image' keys show the per-sample data reshaped to 2d.
    """
    # determine plot setup
    num_keys = len(data.keys())
    num_rows = int((num_keys * 3/4) ** .5)
    num_cols = (num_keys + num_rows - 1) // num_rows
    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)
    display = dict()
    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])
    keys = sorted(data.keys())
    # inspect data: pick display mode and downsampling factor per key
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)
        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max
    # clamp the animated index range so the sliding window stays in bounds
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples
    stop_index = min(stop_index, num_samples - half_signal_window_length)
    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but never used below; the
            # data lookups use `index` directly — verify whether lower-rate keys
            # were meant to be indexed with feature_index.
            feature_index = int(round(index * display[key]['down_factor']))
            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
        frames.append(ims)
    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if not filename.endswith('.mp4'):
        filename += '.mp4'
    ani.save(filename)

View File

@@ -0,0 +1,27 @@
import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F
class FIR(nn.Module):
    """Fixed (non-trainable) linear-phase FIR filter designed with least squares.

    The same filter is applied independently to every channel of the input.
    """
    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        # firls requires an odd number of taps
        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        coeffs = scipy.signal.firls(numtaps, bands, desired, fs=fs)
        self.weight = torch.from_numpy(coeffs.astype(np.float32))

    def forward(self, x):
        """Filter x of shape (batch, channels, time); no padding, so the time axis shrinks by numtaps - 1."""
        num_channels = x.size(1)

        # one copy of the kernel per channel, applied as a grouped convolution
        kernel = self.weight.view(1, 1, -1).repeat(num_channels, 1, 1)

        return F.conv1d(x, kernel, groups=num_channels)

View File

@@ -0,0 +1,230 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.softquant import soft_quant
class LimitedAdaptiveComb1d(nn.Module):
    """Adaptive single-tap-chain comb filter with feature-driven kernels, gains and pitch lags."""
    # instance counter used for auto-generated names
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=[-6, 6],
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """
        Parameters:
        -----------
        kernel_size : int
            kernel size of the adaptive comb-filter taps
        feature_dim : int
            dimension of features from which kernels, biases and gains are computed
        frame_size : int, optional
            frame size, defaults to 160
        overlap_size : int, optional
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
        padding : List[int, int], optional
            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        max_lag : int, optional
            maximal pitch lag, defaults to 256
        gain_limit_db : int, optional
            upper limit for the comb-filter gain in dB, defaults to 10
        global_gain_limits_db : List[int, int], optional
            lower/upper limit for the global output gain in dB, defaults to [-6, 6]
        norm_p : int, optional
            p of the p-norm used to normalize the predicted kernels, defaults to 2
        softquant : bool, optional
            if true, the kernel-predicting layer is soft-quantized, defaults to False
        apply_weight_norm : bool, optional
            if true, weight norm is applied to the predicting layers, defaults to False
        name: str or None, optional
            specifies a name attribute for the module. If None the name is auto generated as limited_adaptive_comb1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
        """
        super(LimitedAdaptiveComb1d, self).__init__()

        # comb filtering operates on a single channel
        self.in_channels = 1
        self.out_channels = 1
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.max_lag = max_lag
        self.limit_db = gain_limit_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))
        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # comb filter gain
        self.filter_gain = norm(nn.Linear(feature_dim, 1))
        # dB -> natural log conversion factor: ln(10) / 20
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            # bias init keeps the initial comb gain small (see exp(-relu(...)) in forward)
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        # global output gain, limited via tanh to global_gain_limits_db
        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # raised-cosine half window used for cross-fading between frames
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive 1d comb filtering with frame-wise lags

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags: torch.LongTensor
            frame-wise lags for comb-filtering
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        # predict one kernel per frame and normalize it to unit p-norm
        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

        # comb gain in (0, exp(log_gain_limit)]
        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        # calculate gains
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        if debug and batch_size == 1:
            # dump internal signals for offline inspection via utils.endoscopy
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
        # pad for the filter taps, then for the maximal lag (left) and the
        # cross-fade look-ahead (right)
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        # gather indices addressing each frame shifted back by its pitch lag
        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)

        for i in range(num_frames):
            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
            # filter the lagged signal; batch entries are handled as groups
            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            offset = self.max_lag + self.padding[0]
            # comb structure: dry signal plus gain-scaled lagged contribution,
            # everything scaled by the global gain
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        """Estimated complexity in FLOPs per second at the given sample rate."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
        # comb gain computation and application
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # global gain computation and application
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count

View File

@@ -0,0 +1,200 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.endoscopy import write_data
from utils.ada_conv import adaconv_kernel
from utils.softquant import soft_quant
class LimitedAdaptiveConv1d(nn.Module):
    """Adaptive 1d convolution with feature-predicted, norm-limited kernels and gains."""
    # instance counter used for auto-generated names
    COUNTER = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 name=None,
                 gain_limits_db=[-6, 6],
                 shape_gain_db=0,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """

        Parameters:
        -----------
        in_channels : int
            number of input channels

        out_channels : int
            number of output channels

        kernel_size : int
            kernel size of the adaptive filters

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int
            frame size

        overlap_size : int
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame

        padding : List[int, int]
            left and right padding; defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        gain_limits_db : List[int, int]
            lower/upper limit for the output gain in dB

        shape_gain_db : int
            mixing gain (in dB, capped at 0) between the predicted kernel and an identity kernel

        norm_p : int
            p of the p-norm used to normalize the predicted kernels

        softquant : bool
            if true, the kernel-predicting layer is soft-quantized

        apply_weight_norm : bool
            if true, weight norm is applied to the predicting layers
        """
        super(LimitedAdaptiveConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.gain_limits_db = gain_limits_db
        self.shape_gain_db = shape_gain_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
            LimitedAdaptiveConv1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # mixing weight between predicted and identity kernel (<= 1)
        self.shape_gain = min(1, 10**(shape_gain_db / 20))

        # per-output-channel gain, limited via tanh to gain_limits_db
        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
        # dB -> natural log conversion factor: ln(10) / 20
        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if type(padding) == type(None):
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # raised-cosine half window used for cross-fading between frames
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def flop_count(self, rate):
        """Estimated complexity in FLOPs per second at the given sample rate."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # gain computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += 3 * overlap * frame_rate * self.out_channels

        return count

    def forward(self, x, features, debug=False):
        """ adaptive 1d convolution

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1/win2: rising and falling cross-fade windows (win2 is consumed
        # inside adaconv_kernel as the mirror of win1)
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))

        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))

        # limit shape: blend the predicted kernel with an identity (delta) kernel
        id_kernels = torch.zeros_like(conv_kernels)
        id_kernels[..., self.padding[1]] = 1

        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels

        # calculate gains
        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
        if debug and batch_size == 1:
            # dump internal signals for offline inspection via utils.endoscopy
            key = self.name + "_gains"
            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # fold the gains into the kernels, then run the FFT-based adaptive convolution
        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)

        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)

        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)


        return output

View File

@@ -0,0 +1,100 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class NoiseShaper(nn.Module):
    """Generates white Gaussian noise scaled per sample by a feature-predicted envelope."""

    def __init__(self,
                 feature_dim,
                 frame_size=160
                 ):
        """
        Parameters:
        -----------
        feature_dim : int
            dimension of input features

        frame_size : int
            frame size
        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size

        # feature transform: two causal 1x2 convolutions predict one
        # log-amplitude per output sample of each frame
        self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)

    def flop_count(self, rate):
        """Estimated complexity in FLOPs per second at the given sample rate."""
        frame_rate = rate / self.frame_size
        conv_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)])
        # activation overhead: leaky_relu + exp + multiply, roughly 11 ops/sample
        return conv_flops + 11 * frame_rate * self.frame_size

    def forward(self, features):
        """ creates temporally shaped noise

        Parameters:
        -----------
        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """
        batch_size, num_frames = features.size(0), features.size(1)

        # causal feature path: left-pad by one frame, two conv stages,
        # exp keeps the predicted envelope strictly positive
        h = F.pad(features.permute(0, 2, 1), [1, 0])
        h = F.leaky_relu(self.feature_alpha1(h), 0.2)
        gain = torch.exp(self.feature_alpha2(F.pad(h, [1, 0]))).permute(0, 2, 1)

        # white noise shaped by the predicted per-sample gains
        noise = torch.randn((batch_size, num_frames, self.frame_size), dtype=features.dtype, device=features.device)
        shaped = gain * noise

        return shaped.reshape(batch_size, 1, num_frames * self.frame_size)

View File

@@ -0,0 +1,84 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
import torch.nn.functional as F
class PitchAutoCorrelator(nn.Module):
    """Frame-wise normalized autocorrelation of a signal around given pitch lags.

    For every frame, the signal is correlated with delayed copies of itself at
    lags period - radius, ..., period + radius, yielding one normalized
    correlation value per frame and lag offset.
    """
    def __init__(self,
                 frame_size=80,
                 pitch_min=32,
                 pitch_max=300,
                 radius=2):
        """
        Parameters:
        -----------
        frame_size : int, optional
            number of samples per frame, defaults to 80
        pitch_min : int, optional
            minimal pitch lag (stored for reference; not used in forward), defaults to 32
        pitch_max : int, optional
            maximal pitch lag; determines the left padding, defaults to 300
        radius : int, optional
            number of neighbouring lags evaluated on each side of the period, defaults to 2
        """
        super().__init__()

        self.frame_size = frame_size
        self.pitch_min = pitch_min
        self.pitch_max = pitch_max
        self.radius = radius

    def forward(self, x, periods):
        """Return normalized autocorrelations of shape (batch, channels, num_frames, 2 * radius + 1).

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, channels, num_samples)
        periods : torch.LongTensor
            frame-wise pitch lags of shape (batch_size, num_frames)
        """
        num_frames = periods.size(1)
        batch_size = periods.size(0)
        num_samples = self.frame_size * num_frames
        channels = x.size(1)
        assert num_samples == x.size(-1)

        # fix: renamed `range` -> `lag_offsets`; the original shadowed the builtin
        lag_offsets = torch.arange(-self.radius, self.radius + 1, device=x.device)
        idx = torch.arange(self.frame_size * num_frames, device=x.device)
        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
        # per-sample lookup positions into the left-padded signal, one column
        # per lag offset around the frame's period
        lookup = idx + self.pitch_max - p_up
        lookup = lookup.unsqueeze(-1) + lag_offsets
        lookup = lookup.unsqueeze(1)

        # padding so that the largest lag stays in range
        x_pad = F.pad(x, [self.pitch_max, 0])
        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)

        # framing
        # NOTE(review): the gather index has a singleton channel dim, so lagged
        # frames are taken from the first channel only — confirm this is intended
        # for multi-channel input.
        x_select = torch.gather(x_ext, 2, lookup)
        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)

        # calculate auto-correlation, normalized by the frame energies
        dotp = torch.sum(x_frames * lag_frames, dim=-2)
        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
        lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)

        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)

        return acorr

View File

@@ -0,0 +1,167 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# Q15 fractional-delay interpolation filter bank (8 taps, 12 phases),
# normalized to floating point by dividing by 2**15.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
    ],
    dtype=np.float32
) / 2**15
# Q16 coefficients of the 2x upsampler's allpass stages (even/odd phase),
# converted to floating point; the third value is wrapped to signed range.
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
def get_impz(coeffs, n):
    """Return the length-n impulse response of the 3-stage allpass-based filter.

    Parameters:
    -----------
    coeffs : sequence of float
        the three stage coefficients
    n : int
        number of impulse-response samples to compute
    """
    state = [0, 0, 0]
    # the third stage uses 1 + c2 as its multiplier
    stage_gains = (coeffs[0], coeffs[1], 1 + coeffs[2])
    y = np.zeros(n)

    x = 1
    for i in range(n):
        v = x
        # run the three first-order stages in cascade
        for k in range(3):
            diff = v - state[k]
            scaled = diff * stage_gains[k]
            out = state[k] + scaled
            state[k] = v + scaled
            v = out
        y[i] = v
        # impulse input: only the first sample is non-zero
        x = 0

    return y
class SilkUpsampler(nn.Module):
    """Upsamples 16 kHz audio to 24 or 48 kHz using an FIR approximation of the SILK resampler."""
    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}

    def __init__(self,
                 fs_in=16000,
                 fs_out=48000):
        """
        Parameters:
        -----------
        fs_in : int, optional
            input sample rate, must be 16000, defaults to 16000
        fs_out : int, optional
            output sample rate, must be 24000 or 48000, defaults to 48000
        """
        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')

        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')

        # hq 2x upsampler as FIR approximation: truncate the (time-reversed)
        # impulse responses of the allpass cascades to 128 taps
        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        hq_2x_up_odd = get_impz(hq_2x_up_c_odd , 128)[::-1].copy()
        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(hq_2x_up_odd ).float().view(1, 1, -1), requires_grad=False)
        # causal left padding matching the 128-tap filters
        self.hq_2x_up_padding = [127, 0]

        # interpolation filters: three phases of the fractional-delay bank
        # used for the 3/2 rate change
        frac_01_24 = frac_fir[0]
        frac_17_24 = frac_fir[8]
        frac_09_24 = frac_fir[4]

        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)

        # 48 kHz keeps all samples; 24 kHz takes every second output sample
        self.stride = 1 if fs_out == 48000 else 2

    def hq_2x_up(self, x):
        """Upsample x (batch, channels, time) by 2 via even/odd polyphase FIR filtering."""

        num_channels = x.size(1)
        # one filter copy per channel, applied as grouped convolutions
        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
        weight_odd = torch.repeat_interleave(self.hq_2x_up_odd , num_channels, 0)

        x_pad = F.pad(x, self.hq_2x_up_padding)
        y_even = F.conv1d(x_pad, weight_even, groups=num_channels)
        y_odd = F.conv1d(x_pad, weight_odd , groups=num_channels)

        # interleave even and odd phases into the double-rate signal
        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)

        return y

    def interpolate_3_2(self, x):
        """Resample x (batch, channels, time) by 3/2 using three fractional-delay phases."""
        num_channels = x.size(1)
        # one filter copy per channel, applied as grouped convolutions
        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)

        x_pad = F.pad(x, [8, 0])
        y_01_24 = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
        y_17_24 = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)
        # third phase is taken one input sample later (roll by -1)
        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)

        # interleave the three phases, drop the tail introduced by the roll
        y = torch.cat(
            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
            dim=-1).flatten(2)

        return y[..., :-3]

    def forward(self, x):
        """Upsample x of shape (batch, channels, time) from 16 kHz to fs_out."""

        y_2x = self.hq_2x_up(x)
        y_3x = self.interpolate_3_2(y_2x)

        return y_3x[:, :, ::self.stride]

View File

@@ -0,0 +1,145 @@
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
from utils.softquant import soft_quant
class TDShaper(nn.Module):
    # Class-level counter constant; presumably used for naming/instance
    # bookkeeping elsewhere in the project -- TODO confirm.
    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
                 ):
        """ Temporal shaper: derives per-sample positive gains from conditioning
        features and the signal's own log-envelope and applies them frame-wise.

        Parameters:
        -----------
        feature_dim : int
            dimension of input features

        frame_size : int, optional
            number of samples per frame (default 160); must be divisible
            by avg_pool_k

        avg_pool_k : int, optional
            kernel size and stride for avg pooling of the envelope

        innovate : bool, optional
            if True, adds a learned innovation branch on top of the shaping path

        pool_after : bool, optional
            if True, envelope takes log after pooling; otherwise log first

        softquant : bool, optional
            if True, wraps feature_alpha1_f in soft quantization noise

        apply_weight_norm : bool, optional
            if True, wraps the conv layers in weight normalization
        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after

        assert frame_size % avg_pool_k == 0
        # pooled envelope samples per frame, plus one slot for the frame mean
        self.env_dim = frame_size // avg_pool_k + 1

        # identity wrapper when weight norm is disabled
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            # NOTE(review): these convs expect feature_dim + env_dim input
            # channels, but forward() feeds them only f (feature_dim channels);
            # the innovate path likely needs torch.cat((f, t), dim=1) -- confirm
            # before enabling innovate=True.
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))

    def flop_count(self, rate):
        """ Rough FLOP/s estimate when run at sample rate `rate`. """

        frame_rate = rate / self.frame_size

        # conv FLOPs plus elementwise ops (pad/relu/exp/multiply) per sample
        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """ Log-domain temporal envelope: pooled magnitudes per frame with the
        frame mean removed and appended as the last channel. """

        x = torch.abs(x)
        if self.pool_after:
            # floor at 2**-16 avoids log(0)
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # (batch, num_frames, frame_size // avg_pool_k)
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        debug : bool, optional
            currently unused
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size  # NOTE: unused local

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path; left-pad by one frame so the kernel-2 convs are causal
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        # exp keeps the per-sample gains strictly positive
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path: apply gains frame-wise
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)

View File

@@ -0,0 +1,112 @@
import os
import torch
import numpy as np
def load_lpcnet_features(feature_file, version=2):
    """Load raw float32 LPCNet features from `feature_file`.

    Returns a dict with 'features' (cepstrum + pitch correlation),
    'periods' (integer pitch periods) and 'lpcs' (LPC coefficients).
    """
    if version == 2:
        frame_length = 36
        layout = {
            'cepstrum': [0, 18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36]
        }
    elif version == 1:
        frame_length = 55
        layout = {
            'cepstrum': [0, 18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
        }
    else:
        raise ValueError(f'unknown feature version: {version}')

    data = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    data = data.reshape((-1, frame_length))

    def columns(name):
        # slice the column range declared in the layout table
        lo, hi = layout[name]
        return data[:, lo:hi]

    features = torch.cat((columns('cepstrum'), columns('pitch_corr')), dim=1)
    # decode quantized period values into integer lags
    periods = (0.1 + 50 * columns('periods') + 100).long()

    return {'features': features, 'periods': periods, 'lpcs': columns('lpc')}
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize `signal_path` (int16 raw) and interleave it, shifted by
    `offset` samples, into a zeroed copy of the reference data layout.

    Side effects: writes '<signal>_preemph.raw' and `new_data_path`.
    """
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    preemphasized = np.memmap(preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    assert len(signal) % 160 == 0

    # frame-wise first-order pre-emphasis y[n] = x[n] - a * x[n-1],
    # carrying the last raw sample across frame boundaries
    prev = np.zeros(1)
    for start in range(0, len(signal), 160):
        frame = signal[start : start + 160]
        preemphasized[start : start + 160] = np.convolve(
            np.concatenate((prev, frame)), [1, -preemph_factor], mode='valid')
        prev = frame[-1:]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
    new_data[:] = 0

    # duplicate each delayed pre-emphasized sample into two adjacent slots
    N = len(signal) - offset
    new_data[1 : 2 * N + 1 : 2] = preemphasized[offset:]
    new_data[2 : 2 * N + 2 : 2] = preemphasized[offset:]
def parse_warpq_scores(output_file):
    """ extracts warpq scores from output file """
    prefix = "WARP-Q score:"
    with open(output_file, "r") as f:
        return [float(line.split(prefix)[-1]) for line in f.readlines() if line.startswith(prefix)]
def parse_stats_file(file):
    """Parse a stats file whose first three lines are 'label: value' pairs.

    Returns (mean, bt_mean, top_mean) as floats, in file order.
    """
    with open(file, "r") as f:
        values = [float(f.readline().split(":")[-1]) for _ in range(3)]
    return tuple(values)
def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()
    for filename in os.listdir(test_folder):
        if not filename.startswith('stats_'):
            continue

        metric = filename[len("stats_") : -len(".txt")]
        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, filename))
        results[metric] = [mean, bt_mean, top_mean]

    return results

View File

@@ -0,0 +1,95 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
    """Return the total element count over all named parameters of *model*.

    With verbose=True, also prints a per-parameter breakdown.
    """
    total = 0
    for name, p in model.named_parameters():
        # per-parameter element count (same value as numel, kept in the
        # original tensor-sum form to preserve the returned numeric type)
        n = torch.ones_like(p).sum().item()
        if verbose:
            print(f"{name}: {n} parameters")
        total = total + n
    return total
def count_nonzero_parameters(model, verbose=False):
    """Return how many parameter entries of *model* are non-zero
    (useful after sparsification). Prints per-parameter counts if verbose."""
    total = 0
    for name, p in model.named_parameters():
        nonzero = torch.count_nonzero(p).item()
        if verbose:
            print(f"{name}: {nonzero} non-zero parameters")
        total = total + nonzero
    return total
def retain_grads(module):
    """Ask autograd to retain .grad on every trainable parameter of *module*."""
    for param in module.parameters():
        if not param.requires_grad:
            continue
        param.retain_grad()
def get_grad_norm(module, p=2):
    """Global p-norm over the gradients of all trainable parameters of *module*."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total = total + torch.abs(param.grad).pow(p).sum()
    return total ** (1 / p)
def create_weights(s_real, s_gen, alpha):
    """Per-discriminator weights exp(alpha * (real_score - gen_score)),
    computed from the last entry of each score list, without tracking grads."""
    with torch.no_grad():
        return [torch.exp(alpha * (real[-1] - fake[-1]))
                for real, fake in zip(s_real, s_gen)]
def _get_candidates(module: torch.nn.Module):
candidates = []
for key in module.__dict__.keys():
if hasattr(module, key + '_v'):
candidates.append(key)
return candidates
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
    """Strip weight normalization from every submodule of *model*.

    Walks all named submodules, detects weight-normalized attributes via
    their '_v' companions and removes the reparametrization in place.

    Parameters:
    -----------
    model : torch.nn.Module
        model to clean up
    verbose : bool, optional
        if True, prints each removed weight norm
    """
    for name, m in model.named_modules():
        candidates = _get_candidates(m)

        for candidate in candidates:
            try:
                remove_weight_norm(m, name=candidate)
                if verbose: print(f'removed weight norm on weight {name}.{candidate}')
            # remove_weight_norm raises ValueError when the attribute is not
            # actually weight-normalized; anything else should propagate
            # (the original bare `except:` also swallowed KeyboardInterrupt)
            except ValueError:
                pass

View File

@@ -0,0 +1,153 @@
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
frame_length = (fs + 49) // 50
x = x[: frame_length * (len(x) // frame_length)]
frames = x.reshape(-1, frame_length)
frame_energy = np.sum(frames ** 2, axis=1)
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
vactive = np.ones_like(frames)
vactive[frame_energy_smooth < max_threshold, :] = 0
vactive = vactive.reshape(-1)
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
filter = filter / filter.sum()
mask = np.convolve(vactive, filter, mode='same')
return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Convert a per-sample mask into per-frame means over overlapping
    frames of `frame_size` samples hopped by `hop_size`; the mask is
    zero-padded or truncated to exactly cover `num_frames` frames."""
    num_samples = frame_size + (num_frames - 1) * hop_size
    if len(mask) < num_samples:
        padded = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        padded = mask[:num_samples]

    frame_means = [np.mean(padded[start : start + frame_size])
                   for start in range(0, num_frames * hop_size, hop_size)]
    return np.array(frame_means)
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Power spectra of windowed frames of `x` (hop `hop_size`); returns an
    array of shape (num_spectra, window_size // 2 + 1)."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    half = window_size // 2

    frames = np.stack([x[start : start + window_size]
                       for start in range(0, num_spectra * hop_size, hop_size)])
    frames = frames * win

    spectrum = np.fft.fft(frames, axis=1)[:, :half + 1]
    return np.abs(spectrum) ** 2
def frequency_mask(num_bands, up_factor, down_factor):
    """Band-to-band spreading matrix: geometric decay with `up_factor`
    towards lower bands and `down_factor` towards higher bands."""
    upward = np.zeros((num_bands, num_bands))
    downward = np.zeros((num_bands, num_bands))

    for band in range(num_bands):
        upward[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        downward[band, band:] = down_factor ** np.arange(num_bands - band)

    return downward @ upward
def rect_fb(band_limits, num_bins=None):
    """Rectangular filter bank: row i is 1 on bins band_limits[i]..band_limits[i+1]-1.

    `num_bins` defaults to the last band edge.
    """
    if num_bins is None:
        num_bins = band_limits[-1]
    num_bands = len(band_limits) - 1

    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool, optional): if True, frames detected as silent are
            excluded from the error average

    Returns:
        float: perceptually weighted error
    """

    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim samples to same size
    num_samples = min(len(x), len(y))
    # rescale to int16 range
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    # power spectra with an additive floor
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: leaky integration across frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]

    # apply mask (same reference-derived mask for both signals)
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log power ratio, averaged per band then per frame
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic mean over active frames, compressed with a 1/6 power
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
if __name__ == "__main__":
    # CLI: compute the MOC score between a reference and a degraded wav file
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both inputs must be 16 kHz; the previous max(fs1, fs2) check let a
    # lower-rate file slip through if the other one was 16 kHz
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # scale int16 samples to [-1, 1]
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")

View File

@@ -0,0 +1,122 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
def hangover(lags, num_frames=10):
    """Extend the last non-zero pitch lag over up to `num_frames` following
    zero (unvoiced) frames. Returns a new array; input is not modified."""
    lags = lags.copy()
    filled = 0
    previous_lag = 0

    for i, lag in enumerate(lags):
        if lag != 0:
            filled = 0
            previous_lag = lag
        elif filled < num_frames:
            lags[i] = previous_lag
            filled += 1

    return lags
def smooth_pitch_lags(lags, d=2):
    """Smooth pitch lags with a symmetric triangular kernel of radius `d`,
    processing one SILK frame (4 lags) at a time; all-zero (unvoiced)
    frames are left untouched and excluded from the context."""
    assert d < 4
    num_silk_frames = len(lags) // 4

    smoothed = lags.copy()

    # triangular kernel [1..d, d+1, d..1], normalized to unit sum
    ramp = np.arange(1, d + 1)
    kernel = np.concatenate((ramp, [d + 1], ramp[::-1]), dtype=np.float32)
    kernel = kernel / np.sum(kernel)

    left = lags[0:d][::-1]
    for i in range(num_silk_frames):
        frame = lags[i * 4 : (i + 1) * 4]

        if np.max(np.abs(frame)) == 0:
            # unvoiced frame: skip, but remember its tail as left context
            left = frame[4 - d:]
            continue

        if i == num_silk_frames - 1:
            right = frame[4 - d:][::-1]
        else:
            right = lags[(i + 1) * 4 : (i + 1) * 4 + d]

        # fall back to mirrored frame edges when a neighbour is unvoiced
        if np.max(np.abs(right)) == 0:
            right = frame[4 - d:][::-1]
        if np.max(np.abs(left)) == 0:
            left = frame[0:d][::-1]

        padded = np.concatenate((left, frame, right), dtype=np.float32)
        smoothed[i * 4 : (i + 1) * 4] = np.round(np.convolve(padded, kernel, mode='valid'))

        left = frame[4 - d:]

    return smoothed
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
    """ Normalized autocorrelations of each frame against its lagged past.

    For every frame, correlates the frame with the signal delayed by
    lag - radius ... lag + radius, where lag is the frame's pitch lag; with
    add_double_lag_acorr, a second set around 2 * lag is appended (the lag is
    only doubled when lag >= no_pitch_threshold, i.e. when it denotes real pitch).

    Parameters
    ----------
    x : np.ndarray
        signal; length must be a multiple of frame_size
    frame_size : int
        samples per frame
    lags : np.ndarray
        per-frame pitch lags (returned as an unmodified copy)
    history : np.ndarray, optional
        samples preceding x; zeros of the minimal safe length when None
    max_lag : int
        largest supported lag (sizes the default history)
    radius : int
        half-width of the lag search window
    add_double_lag_acorr : bool
        also correlate around twice the lag
    no_pitch_threshold : int
        lags below this value are treated as no-pitch markers

    Returns
    -------
    (np.ndarray, np.ndarray)
        acorrs of shape (num_frames, (2 if add_double_lag_acorr else 1) * (2*radius+1)),
        and the copied lags
    """
    eps = 1e-9
    lag_multiplier = 2 if add_double_lag_acorr else 1

    if history is None:
        history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)

    # all indexing below is relative to the start of x within x_ext
    offset = len(history)
    assert offset >= max_lag + radius
    assert len(x) % frame_size == 0

    num_frames = len(x) // frame_size
    lags = lags.copy()

    x_ext = np.concatenate((history, x), dtype=x.dtype)

    d = radius
    num_acorrs = 2 * d + 1
    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)

    for idx in range(num_frames):
        lag = lags[idx].item()
        frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]

        for k in range(lag_multiplier):
            # double the lag only when it encodes a genuine pitch period
            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
            for j in range(num_acorrs):
                # past segment shifted by lag1 - d + j relative to the frame
                past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
                # normalized cross-correlation; eps guards against zero energy
                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)

    return acorrs, lags

View File

@@ -0,0 +1,144 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import numpy as np
import torch
import scipy
import scipy.signal
from utils.pitch import hangover, calculate_acorr_window
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
def spec_from_lpc(a, n_fft=128, eps=1e-9):
    """Power spectrum of the AR synthesis filter 1/A(z) given LPC
    coefficients `a` of shape (..., order); returns (..., n_fft//2 + 1)."""
    order = a.shape[-1]
    assert order + 1 < n_fft

    # build A(z) = 1 - a1 z^-1 - ... - a_order z^-order, zero-padded to n_fft
    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1:1 + order] = -a

    response = np.abs(np.fft.fft(poly, axis=-1)[..., : n_fft // 2 + 1]) ** 2

    # invert |A|^2 to get the AR power spectrum; eps avoids division by zero
    return 1 / (response + eps)
def silk_feature_factory(no_pitch_value=256,
                         acorr_radius=2,
                         pitch_hangover=8,
                         num_bands_clean_spec=64,
                         num_bands_noisy_spec=18,
                         noisy_spec_scale='opus',
                         noisy_apply_dct=True,
                         add_double_lag_acorr=False
                         ):
    """ Returns a closure that assembles the SILK enhancement feature vector.

    Parameters:
    -----------
    no_pitch_value : int
        lag substituted for zero (unvoiced) periods after hangover
    acorr_radius : int
        lag search radius for the autocorrelation features
    pitch_hangover : int
        number of frames to carry the last pitch lag over unvoiced frames
    num_bands_clean_spec / num_bands_noisy_spec : int
        filter bank sizes for the LPC and noisy-signal spectra
    noisy_spec_scale : str
        frequency scale for the noisy-signal filter bank
    noisy_apply_dct : bool
        if True, the noisy spectrum is converted to a cepstrum (DCT)
    add_double_lag_acorr : bool
        also compute autocorrelations around twice the pitch lag
    """

    # 20 ms analysis window and filter banks, shared by all created features
    w = scipy.signal.windows.cosine(320)
    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)

    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods):
        """ Concatenates clean LPC spectrum, noisy cepstrum/spectrum,
        autocorrelations, LTP coefficients and log gains into one
        float32 feature matrix; also returns the adjusted periods. """
        periods = periods.copy()

        if pitch_hangover > 0:
            periods = hangover(periods, num_frames=pitch_hangover)

        # mark unvoiced frames with the dedicated no-pitch lag
        periods[periods == 0] = no_pitch_value

        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)

        # prepend 10 ms of history for 50% overlapping frames; np.repeat
        # upsamples the 10 ms spectra to the 5 ms feature rate
        if noisy_apply_dct:
            noisy_cepstrum = np.repeat(
                cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
        else:
            noisy_cepstrum = np.repeat(
                log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)

        log_gains = np.log(gains + 1e-9).reshape(-1, 1)

        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)

        features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)

        return features, periods.astype(np.int64)

    return create_features
def load_inference_data(path,
                        no_pitch_value=256,
                        skip=92,
                        preemph=0.85,
                        acorr_radius=2,
                        pitch_hangover=8,
                        num_bands_clean_spec=64,
                        num_bands_noisy_spec=18,
                        noisy_spec_scale='opus',
                        noisy_apply_dct=True,
                        add_double_lag_acorr=False,
                        **kwargs):
    """ Loads a feature dump produced by the patched opus_demo from `path`.

    Reads LPCs, LTPs, gains, pitch periods and bit counts, trims everything
    to a whole number of 20 ms packets, builds the SILK enhancement features
    and pre-emphasizes the delayed signal.

    Returns:
    --------
    (signal, features, periods, numbits) as torch tensors.
    """

    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")

    lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
    ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
    gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)

    # load signal, add back delay and pre-emphasize
    signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)

    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_double_lag_acorr)

    # trim all streams to a whole number of 20 ms packets (4 frames each)
    num_frames = min((len(signal) // 320) * 4, len(lpcs))
    signal = signal[: num_frames * 80]
    lpcs = lpcs[: num_frames]
    ltps = ltps[: num_frames]
    gains = gains[: num_frames]
    periods = periods[: num_frames]
    num_bits = num_bits[: num_frames // 4]
    # bug fix: this previously sliced num_bits again, silently discarding the
    # smoothed bit counts loaded above
    num_bits_smooth = num_bits_smooth[: num_frames // 4]

    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)

    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods)

    if preemph > 0:
        signal[1:] -= preemph * signal[:-1]

    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)

View File

@@ -0,0 +1,110 @@
import torch
@torch.no_grad()
def compute_optimal_scale(weight):
    """Per-output-row quantization scale for a 2-D weight matrix: the larger
    of max|w|/127 and max|w_even + w_odd|/129 for each row."""
    with torch.no_grad():
        n_out, n_in = weight.shape
        assert n_in % 4 == 0
        if n_out % 8:
            # NOTE(review): pads n_out - n_out % 8 extra rows, which does not
            # generally round n_out up to a multiple of 8; the padded rows are
            # all-zero and the result is sliced back to n_out, so the returned
            # scales are unaffected either way.
            pad = n_out - n_out % 8
            weight = torch.cat(
                (weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)),
                dim=0)

        row_abs_max, _ = torch.max(torch.abs(weight), dim=1)
        row_pair_max, _ = torch.max(torch.abs(weight[:, 0:n_in:2] + weight[:, 1:n_in:2]), dim=1)

        scale = torch.maximum(row_abs_max / 127, row_pair_max / 129)

        return scale[:n_out]
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Uniform noise in [-scale/2, scale/2) per output row, where scale is the
    optimal quantization scale of the (suitably flattened) weight; the noise
    is returned in the weight's original layout."""
    if isinstance(module, torch.nn.Conv1d):
        # flatten (out, in, k) -> (out, k * in) before scaling per row
        flat = weight.permute(0, 2, 1).flatten(1)
        row_noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = row_noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        i, o, k = weight.shape
        flat = weight.permute(2, 1, 0).reshape(k * o, i)
        row_noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = row_noise.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        noise = (torch.rand_like(weight) - 0.5) * compute_optimal_scale(weight).unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')

    return noise
class SoftQuant:
    """Hook pair that injects quantization noise into selected weights during
    training: the forward pre-hook adds noise, the forward hook removes it,
    so optimizer steps always see the clean weights."""
    name: str

    def __init__(self, names, scale) -> None:
        # names: list of weight attribute names to perturb
        # scale: relative noise amplitude; None selects the per-row
        #        optimal-scale noise from q_scaled_noise
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Pre-hook (before=True) adds noise; post-hook (before=False)
        subtracts the exact same noise again. No-op in eval mode."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    def apply(module, names=['weight'], scale=None):
        """Register the noise hooks on *module* and return the SoftQuant
        instance (also reachable later via the hooks' .sqm attribute)."""
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                # previously raised ValueError("") with no message
                raise ValueError(f"module has no attribute named {name}")

        fn_before = lambda *x: fn(*x, before=True)
        fn_after = lambda *x: fn(*x, before=False)
        # tag the hooks so remove_soft_quant can find them again
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)
        # removed a stray no-op `module` expression statement here

        return fn
def soft_quant(module, names=['weight'], scale=None):
    """Attach soft-quantization noise hooks to *module* and return it.

    Thin convenience wrapper around SoftQuant.apply (whose return value,
    previously bound to an unused local, is intentionally discarded).
    """
    SoftQuant.apply(module, names, scale)
    return module
def remove_soft_quant(module, names=['weight']):
    """Remove the soft-quantization hooks installed by soft_quant for the
    given weight names; other hooks are left in place. Returns the module.

    Iterates over snapshots of the hook dicts: deleting from an OrderedDict
    while iterating it raises "dictionary changed size during iteration",
    which the original code did whenever a matching hook was found.
    """
    for k, hook in list(module._forward_pre_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_hooks[k]

    return module

View File

@@ -0,0 +1,210 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import numpy as np
import scipy
import scipy.fftpack
import torch
def erb(f):
    """Equivalent rectangular bandwidth scale value for frequency f
    (Glasberg & Moore formulation; f presumably in kHz -- TODO confirm)."""
    return 24.7 * (4.37 * f + 1)

def inv_erb(e):
    """Inverse of erb()."""
    return (e / 24.7 - 1) / 4.37

def bark(f):
    """Frequency (Hz) to Bark via the asinh approximation."""
    return 6 * m.asinh(f / 600)

def inv_bark(b):
    """Bark back to frequency (Hz)."""
    return 600 * m.sinh(b / 6)

# scale name -> [forward, inverse] conversion pair
scale_dict = {
    'bark': [bark, inv_bark],
    'erb': [erb, inv_erb]
}
def gen_filterbank(N, Fs=16000, keep_size=False):
    """ERB-width spreading filterbank as a torch tensor of shape
    (N or N+1, N+1), each output row normalized to unit sum."""
    freqs_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    M = N + 1 if keep_size else N
    freqs_out = (np.arange(M, dtype='float32') / N * Fs / 2)[:, None]

    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_bw = 24.7 + .108 * freqs_in
    delta = np.abs(freqs_in - freqs_out) / erb_bw

    # piecewise response in dB: quadratic rolloff near the center,
    # linear further out
    near = (delta < .5).astype('float32')
    response_db = -12 * near * delta ** 2 + (1 - near) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)

    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
    """ Triangular filter bank of shape (num_bands, n_fft//2 + 1).

    Band centers are spaced uniformly on the chosen perceptual scale
    ('bark', 'erb') or taken from the fixed Opus band layout ('opus');
    each FFT bin is split linearly between its two neighbouring bands.
    """
    f0 = 0
    num_bins = n_fft // 2 + 1
    f1 = fs / n_fft * (num_bins - 1)   # Nyquist frequency
    fstep = fs / n_fft                 # FFT bin spacing in Hz

    if scale == 'opus':
        # fixed Opus band edges given in 5 ms-unit bins; rescale to n_fft bins
        bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
        fac = 1000 * n_fft / fs / 5
        if num_bands != 18:
            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
            num_bands = 18
        center_bins = np.array([fac * bin for bin in bins_5ms])
    else:
        # centers uniform on the perceptual scale, endpoints pinned to 0 and Nyquist
        to_scale, from_scale = scale_dict[scale]
        s0 = to_scale(f0)
        s1 = to_scale(f1)
        center_freqs = np.array([f0] + [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)] + [f1])
        center_bins = (center_freqs - f0) / fstep

    if round_center_bins:
        center_bins = np.round(center_bins)

    filter_bank = np.zeros((num_bands, num_bins))
    band = 0
    for bin in range(num_bins):
        # update band index
        if bin > center_bins[band + 1]:
            band += 1

        # calculate filter coefficients: linear split of the bin between
        # band and band + 1 (band never exceeds num_bands - 2 because the
        # last center bin equals the last FFT bin)
        frac = (center_bins[band + 1] - bin) / (center_bins[band + 1] - center_bins[band])
        filter_bank[band][bin] = frac
        filter_bank[band + 1][bin] = 1 - frac

    if return_upper:
        # mirror the spectrum to cover the upper (negative-frequency) bins
        extend = n_fft - num_bins
        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend+1])), axis=1)

    if normalize:
        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)

    return filter_bank
def compressed_log_spec(pspec):
    """log10 of a power spectrum with a running floor: each band is clipped
    from below by a global floor and by a per-band decaying follower."""
    lpspec = np.zeros_like(pspec)

    floor_global = -2
    floor_follow = -2

    for i in range(pspec.shape[-1]):
        value = np.log10(pspec[i] + 1e-9)
        value = max(floor_global, max(floor_follow - 2.5, value))
        lpspec[i] = value
        floor_global = max(floor_global, value)
        floor_follow = max(floor_follow - 2.5, value)

    return lpspec
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
    """Log power spectrum of the AR filter 1/A(z) from SILK LPCs `a`
    (..., order), optionally bandwidth-expanded by gamma, band-pooled by
    filter bank `fb` and floor-compressed."""
    order = a.shape[-1]
    assert order + 1 < n_fft

    # bandwidth expansion: a_k <- a_k * gamma^k
    a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)

    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1:1 + order] = -a

    mag = np.abs(np.fft.fft(poly, axis=-1)[..., : n_fft // 2 + 1]) ** power
    S = 1 / (mag + eps)

    Sf = S if fb is None else np.matmul(S, fb.T)

    if compress:
        return np.apply_along_axis(compressed_log_spec, -1, Sf)
    return np.log(Sf + eps)
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
    """DCT-II (orthonormal) cepstrum of the LPC log spectrum."""
    log_spec = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
    """Log magnitude spectra on 50% overlapping frames of `x`, optionally
    windowed and pooled through filter bank `fb`."""
    assert (2 * len(x)) % frame_size == 0
    assert frame_size % 2 == 0

    n = len(x)
    half = frame_size // 2
    num_even = n // frame_size
    num_odd = (n - half) // frame_size
    num_bins = half + 1

    even_frames = x[: num_even * frame_size].reshape(-1, frame_size)
    odd_frames = x[half : half + frame_size * num_odd].reshape(-1, frame_size)

    # interleave: even-offset frames at rows 0,2,..., odd-offset at 1,3,...
    frames = np.empty((num_even + num_odd, frame_size), dtype=x.dtype)
    frames[::2, :] = even_frames
    frames[1::2, :] = odd_frames

    if window is not None:
        frames *= window.reshape(1, -1)

    X = np.abs(np.fft.fft(frames, n=frame_size, axis=-1))[:, :num_bins] ** power

    if fb is not None:
        X = np.matmul(X, fb.T)

    return np.log(X + 1e-9)
def cepstrum(x, frame_size, fb=None, window=None):
    """DCT-II (orthonormal) cepstrum on 50% overlapping frames of `x`."""
    log_spec = log_spectrum(x, frame_size, fb, window)
    return scipy.fftpack.dct(log_spec, 2, norm='ortho')

View File

@@ -0,0 +1,347 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
setup_dict = dict()
# Default setup for regression-loss training of the LACE model.
# FIX: the 'data' section originally listed 'pitch_hangover' twice (8, then 0);
# Python dict literals keep the last occurrence, so the effective value 0 is
# kept here and the shadowed duplicate is removed.
lace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',

    # model class and constructor arguments
    'model': {
        'name': 'lace',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 128,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },

    # dataset / feature-extraction parameters
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        'pitch_hangover': 0,
    },

    # optimizer and regression-loss weights
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
# Default setup for regression-loss training of the NoLACE model.
# FIX: the 'data' section originally listed 'pitch_hangover' twice (8, then 0);
# Python dict literals keep the last occurrence, so the effective value 0 is
# kept here and the shadowed duplicate is removed.
nolace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',

    # model class and constructor arguments
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },

    # dataset / feature-extraction parameters
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        'pitch_hangover': 0,
    },

    # optimizer and regression-loss weights
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
# Default setup for adversarial training of the NoLACE model.
# FIX: the 'data' section originally listed 'pitch_hangover' twice (8, then 0);
# Python dict literals keep the last occurrence, so the effective value 0 is
# kept here and the shadowed duplicate is removed.
nolace_setup_adv = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',

    # generator model class and constructor arguments
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [0, 0, 200]
        }
    },

    # dataset / feature-extraction parameters
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
        'pitch_hangover': 0,
    },

    # discriminator class and constructor arguments
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },

    # optimizer settings and adversarial/regression loss weights
    'training': {
        'adv_target': 'target_orig',
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 10,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 20,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
    }
}
# Default setup for regression-loss training of the LaVoce vocoder.
lavoce_setup = dict(
    data=dict(
        frames_per_sample=100,
        target='signal',
    ),
    dataset='/local/datasets/lpcnet_large/training',
    model=dict(
        args=[],
        kwargs=dict(
            comb_gain_limit_db=10,
            cond_dim=256,
            conv_gain_limits_db=[-12, 12],
            global_gain_limits_db=[-6, 6],
            kernel_size=15,
            num_features=19,
            pitch_embedding_dim=64,
            pitch_max=300,
            preemph=0.85,
            pulses=True,
        ),
        name='lavoce',
    ),
    training=dict(
        batch_size=256,
        epochs=50,
        loss=dict(
            w_l1=0,
            w_l2=0,
            w_lm=0,
            w_logmel=0,
            w_sc=0,
            w_slm=2,
            w_sxcorr=1,
            w_wsc=0,
            w_xcorr=0,
        ),
        lr=0.0005,
        lr_decay_factor=2.5e-05,
    ),
    validation_dataset='/local/datasets/lpcnet_large/validation',
)
# Default setup for adversarial training of the LaVoce vocoder.
lavoce_setup_adv = dict(
    data=dict(
        frames_per_sample=100,
        target='signal',
    ),
    dataset='/local/datasets/lpcnet_large/training',
    discriminator=dict(
        args=[],
        kwargs=dict(
            architecture='free',
            design='f_down',
            fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
            freq_roi=[0, 7400],
            fs=16000,
            max_channels=256,
            noise_gain=0.0,
        ),
        name='fdmresdisc',
    ),
    model=dict(
        args=[],
        kwargs=dict(
            comb_gain_limit_db=10,
            cond_dim=256,
            conv_gain_limits_db=[-12, 12],
            global_gain_limits_db=[-6, 6],
            kernel_size=15,
            num_features=19,
            pitch_embedding_dim=64,
            pitch_max=300,
            preemph=0.85,
            pulses=True,
        ),
        name='lavoce',
    ),
    training=dict(
        batch_size=64,
        epochs=50,
        gen_lr_reduction=1,
        lambda_feat=1.0,
        lambda_reg=0.6,
        loss=dict(
            w_l1=0,
            w_l2=0,
            w_lm=0,
            w_logmel=0,
            w_sc=0,
            w_slm=2,
            w_sxcorr=1,
            w_wsc=0,
            w_xcorr=0,
        ),
        lr=0.0001,
        lr_decay_factor=2.5e-09,
    ),
)
# Registry mapping setup names (as used on the command line) to their
# configuration dicts.
setup_dict = {}
setup_dict['lace'] = lace_setup
setup_dict['nolace'] = nolace_setup
setup_dict['nolace_adv'] = nolace_setup_adv
setup_dict['lavoce'] = lavoce_setup
setup_dict['lavoce_adv'] = lavoce_setup_adv