add some code
This commit is contained in:
66
managed_components/78__esp-opus/dnn/torch/osce/README.md
Normal file
66
managed_components/78__esp-opus/dnn/torch/osce/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Opus Speech Coding Enhancement
|
||||
|
||||
This folder hosts models for enhancing Opus SILK.
|
||||
|
||||
## Environment setup
|
||||
The code is tested with python 3.11. Conda setup is done via
|
||||
|
||||
|
||||
`conda create -n osce python=3.11`
|
||||
|
||||
`conda activate osce`
|
||||
|
||||
`python -m pip install -r requirements.txt`
|
||||
|
||||
|
||||
## Generating training data
|
||||
First step is to convert all training items to 16 kHz and 16 bit pcm and then concatenate them. A convenient way to do this is to create a file list and then run
|
||||
|
||||
`python scripts/concatenator.py filelist 16000 dataset/clean.s16 --db_min -40 --db_max 0`
|
||||
|
||||
which on top provides some random scaling. Data is taken from the datasets listed in dnn/datasets.txt and the exact list of items used for training and validation is
|
||||
located in dnn/torch/osce/resources.
|
||||
|
||||
Second step is to run a patched version of opus_demo in the dataset folder, which will produce the coded output and add feature files. To build the patched opus_demo binary, check out the exp-neural-silk-enhancement branch and build opus_demo the usual way. Then run
|
||||
|
||||
`cd dataset && <path_to_patched_opus_demo>/opus_demo voip 16000 1 9000 -silk_random_switching 249 clean.s16 coded.s16 `
|
||||
|
||||
The argument to -silk_random_switching specifies the number of frames after which parameters are switched randomly.
|
||||
|
||||
## Regression loss based training
|
||||
Create a default setup for LACE or NoLACE via
|
||||
|
||||
`python make_default_setup.py model.yml --model lace/nolace --path2dataset <path2dataset>`
|
||||
|
||||
Then run
|
||||
|
||||
`python train_model.py model.yml <output folder> --no-redirect`
|
||||
|
||||
for running the training script in foreground or
|
||||
|
||||
`nohup python train_model.py model.yml <output folder> &`
|
||||
|
||||
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
|
||||
|
||||
## Adversarial training (NoLACE only)
|
||||
Create a default setup for NoLACE via
|
||||
|
||||
`python make_default_setup.py nolace_adv.yml --model nolace --adversarial --path2dataset <path2dataset>`
|
||||
|
||||
Then run
|
||||
|
||||
`python adv_train_model.py nolace_adv.yml <output folder> --no-redirect`
|
||||
|
||||
for running the training script in foreground or
|
||||
|
||||
`nohup python adv_train_model.py nolace_adv.yml <output folder> &`
|
||||
|
||||
to run it in background. In the latter case the output is written to `<output folder>/out.txt`.
|
||||
|
||||
## Inference
|
||||
Generating inference data is analogous to generating training data. Given an item 'item1.wav' run
|
||||
`mkdir item1.se && sox item1.wav -r 16000 -e signed-integer -b 16 item1.raw && cd item1.se && <path_to_patched_opus_demo>/opus_demo voip 16000 1 <bitrate> ../item1.raw noisy.s16`
|
||||
|
||||
The folder item1.se then serves as input for the test_model.py script or for the --testdata argument of train_model.py or adv_train_model.py, respectively.
|
||||
|
||||
autogen.sh downloads pre-trained model weights to the subfolder dnn/models of the main repo.
|
||||
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import math as m
|
||||
import random
|
||||
|
||||
import yaml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
import pesq
|
||||
|
||||
from data import SilkEnhancementSet
|
||||
from models import model_dict
|
||||
|
||||
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder: if it already exists, ask the user whether to
# continue (existing checkpoints may be overwritten); otherwise create it
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires an exit status; calling it without
        # arguments raises TypeError instead of terminating the process
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
|
||||
ref = None
|
||||
if args.testdata is not None:
|
||||
|
||||
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
|
||||
|
||||
inference_test = True
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
|
||||
|
||||
try:
|
||||
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
inference_test = False
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
lr_gen = lr * setup['training']['gen_lr_reduction']
|
||||
lambda_feat = setup['training']['lambda_feat']
|
||||
lambda_reg = setup['training']['lambda_reg']
|
||||
adv_target = setup['training'].get('adv_target', 'target')
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = SilkEnhancementSet(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
# create discriminator
|
||||
disc_name = setup['discriminator']['name']
|
||||
disc = model_dict[disc_name](
|
||||
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
|
||||
)
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
|
||||
|
||||
# disc optimizer
|
||||
parameters = [p for p in disc.parameters() if p.requires_grad]
|
||||
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# restore training state from an initial checkpoint, if one was given:
# model/discriminator weights, both optimizers, the LR scheduler and the
# numpy/python RNG states, so training resumes (mostly) deterministically
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: the guard tested 'scheduler_state_disc' but the load used
    # 'scheduler_state_dict' (the key actually saved below), so the scheduler
    # state was never restored from checkpoints written by this script
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Return 1 minus the normalized cross-correlation, averaged over the batch.

    Reduces over all non-batch dimensions; a small epsilon (1e-9) inside the
    square root guards against division by zero for all-zero signals.
    """
    reduce_dims = list(range(1, y_true.dim()))

    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    denominator = torch.sqrt(
        torch.sum(y_true ** 2, dim=reduce_dims)
        * torch.sum(y_pred ** 2, dim=reduce_dims)
        + 1e-9
    )

    return torch.mean(1 - numerator / denominator)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain L2 loss, scaled by the RMS of the prediction.

    Per item: mean squared error divided by (RMS(y_pred) + 1e-6), then
    averaged over the batch. The epsilon avoids division by zero.
    """
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    pred_rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5

    return (mse / (pred_rms + 1e-6)).mean()
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 loss with optional relative scaling.

    With pow=0 this is the plain mean absolute error; with pow=1 each item's
    error is divided by the mean absolute value of the prediction (+1e-9),
    making the loss scale-invariant. Averaged over the batch.
    """
    reduce_dims = list(range(1, y_true.dim()))

    abs_err = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return torch.mean(abs_err / scale)
|
||||
|
||||
def criterion(x, y):
    """Combined regression loss between prediction x and target y.

    Weighted sum of time-domain L1, multi-resolution STFT, log-mel,
    cross-correlation and scaled L2 terms, normalized by the total weight
    w_sum (weights are module-level values read from the setup file).
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
log_interval = 10
|
||||
|
||||
|
||||
m_r = 0
|
||||
m_f = 0
|
||||
s_r = 1
|
||||
s_f = 1
|
||||
|
||||
def optimizer_to(optim, device):
    """Move every tensor held in an optimizer's state to *device*, in place.

    Needed because load_state_dict / construction may leave optimizer state
    (e.g. Adam moments) on a different device than the model parameters.
    Handles both bare tensors and the usual per-parameter state dicts.
    """
    def _move(tensor):
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            _move(state)
        elif isinstance(state, dict):
            for entry in state.values():
                if isinstance(entry, torch.Tensor):
                    _move(entry)
|
||||
|
||||
optimizer_to(optimizer, device)
|
||||
optimizer_to(optimizer_disc, device)
|
||||
|
||||
retain_grads(model)
|
||||
retain_grads(disc)
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
|
||||
model.to(device)
|
||||
disc.to(device)
|
||||
model.train()
|
||||
disc.train()
|
||||
|
||||
running_disc_loss = 0
|
||||
running_adv_loss = 0
|
||||
running_feature_loss = 0
|
||||
running_reg_loss = 0
|
||||
running_disc_grad_norm = 0
|
||||
running_model_grad_norm = 0
|
||||
|
||||
with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
|
||||
for i, batch in enumerate(tepoch):
|
||||
|
||||
# set gradients to zero
|
||||
optimizer.zero_grad()
|
||||
|
||||
# push batch to device
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device)
|
||||
|
||||
target = batch['target'].to(device)
|
||||
disc_target = batch[adv_target].to(device)
|
||||
|
||||
# calculate model output
|
||||
output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])
|
||||
|
||||
# discriminator update
|
||||
scores_gen = disc(output.detach())
|
||||
scores_real = disc(disc_target.unsqueeze(1))
|
||||
|
||||
disc_loss = 0
|
||||
for score in scores_gen:
|
||||
disc_loss += (((score[-1]) ** 2)).mean()
|
||||
m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item()
|
||||
s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item()
|
||||
|
||||
for score in scores_real:
|
||||
disc_loss += (((1 - score[-1]) ** 2)).mean()
|
||||
m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item()
|
||||
s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item()
|
||||
|
||||
disc_loss = 0.5 * disc_loss / len(scores_gen)
|
||||
winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
|
||||
|
||||
disc.zero_grad()
|
||||
disc_loss.backward()
|
||||
|
||||
running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item()
|
||||
|
||||
optimizer_disc.step()
|
||||
|
||||
# generator update
|
||||
scores_gen = disc(output)
|
||||
|
||||
# calculate loss
|
||||
loss_reg = criterion(output.squeeze(1), target)
|
||||
|
||||
num_discs = len(scores_gen)
|
||||
gen_loss = 0
|
||||
for score in scores_gen:
|
||||
gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs
|
||||
|
||||
loss_feat = 0
|
||||
for k in range(num_discs):
|
||||
num_layers = len(scores_gen[k]) - 1
|
||||
f = 4 / num_discs / num_layers
|
||||
for l in range(num_layers):
|
||||
loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())
|
||||
|
||||
model.zero_grad()
|
||||
|
||||
(gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
# sparsification
|
||||
if hasattr(model, 'sparsifier'):
|
||||
model.sparsifier()
|
||||
|
||||
running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
|
||||
running_adv_loss += gen_loss.detach().cpu().item()
|
||||
running_disc_loss += disc_loss.detach().cpu().item()
|
||||
running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
|
||||
running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()
|
||||
|
||||
# update status bar
|
||||
if i % log_interval == 0:
|
||||
tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
|
||||
disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
|
||||
feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
|
||||
reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
|
||||
model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}",
|
||||
disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}",
|
||||
wc=f"{100*winning_chance:5.2f}%")
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['disc_state_dict'] = disc.state_dict()
|
||||
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
|
||||
checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
|
||||
checkpoint['scheduler_state_dict'] = scheduler.state_dict()
|
||||
checkpoint['torch_rng_state'] = torch.get_rng_state()
|
||||
checkpoint['numpy_rng_state'] = np.random.get_state()
|
||||
checkpoint['python_rng_state'] = random.getstate()
|
||||
checkpoint['adv_loss'] = running_adv_loss/(i + 1)
|
||||
checkpoint['disc_loss'] = running_disc_loss/(i + 1)
|
||||
checkpoint['feature_loss'] = running_feature_loss/(i + 1)
|
||||
checkpoint['reg_loss'] = running_reg_loss/(i + 1)
|
||||
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
print('Done')
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import math as m
|
||||
import random
|
||||
|
||||
import yaml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
import pesq
|
||||
|
||||
from data import LPCNetVocodingDataset
|
||||
from models import model_dict
|
||||
|
||||
|
||||
from utils.lpcnet_features import load_lpcnet_features
|
||||
from utils.misc import count_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder: if it already exists, ask the user whether to
# continue (existing checkpoints may be overwritten); otherwise create it
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires an exit status; calling it without
        # arguments raises TypeError instead of terminating the process
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
|
||||
ref = None
|
||||
# prepare inference test if wanted
|
||||
inference_test = False
|
||||
if type(args.test_features) != type(None):
|
||||
test_features = load_lpcnet_features(args.test_features)
|
||||
features = test_features['features']
|
||||
periods = test_features['periods']
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(inference_folder, exist_ok=True)
|
||||
inference_test = True
|
||||
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
lr_gen = lr * setup['training']['gen_lr_reduction']
|
||||
lambda_feat = setup['training']['lambda_feat']
|
||||
lambda_reg = setup['training']['lambda_reg']
|
||||
adv_target = setup['training'].get('adv_target', 'target')
|
||||
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = LPCNetVocodingDataset(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
|
||||
# create discriminator
|
||||
disc_name = setup['discriminator']['name']
|
||||
disc = model_dict[disc_name](
|
||||
*setup['discriminator']['args'], **setup['discriminator']['kwargs']
|
||||
)
|
||||
|
||||
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr_gen)
|
||||
|
||||
# disc optimizer
|
||||
parameters = [p for p in disc.parameters() if p.requires_grad]
|
||||
optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9])
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# restore training state from an initial checkpoint, if one was given:
# model/discriminator weights, both optimizers, the LR scheduler and the
# numpy/python RNG states, so training resumes (mostly) deterministically
if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location=device)
    model.load_state_dict(chkpt['state_dict'])

    if 'disc_state_dict' in chkpt:
        print(f"loading discriminator state dict from {args.initial_checkpoint}...")
        disc.load_state_dict(chkpt['disc_state_dict'])

    if 'optimizer_state_dict' in chkpt:
        print(f"loading optimizer state dict from {args.initial_checkpoint}...")
        optimizer.load_state_dict(chkpt['optimizer_state_dict'])

    if 'disc_optimizer_state_dict' in chkpt:
        print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...")
        optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict'])

    # bug fix: the guard tested 'scheduler_state_disc' but the load used
    # 'scheduler_state_dict', so the scheduler state was never restored
    # from checkpoints that actually contain it
    if 'scheduler_state_dict' in chkpt:
        print(f"loading scheduler state dict from {args.initial_checkpoint}...")
        scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    # if 'torch_rng_state' in chkpt:
    #     print(f"setting torch RNG state from {args.initial_checkpoint}...")
    #     torch.set_rng_state(chkpt['torch_rng_state'])

    if 'numpy_rng_state' in chkpt:
        print(f"setting numpy RNG state from {args.initial_checkpoint}...")
        np.random.set_state(chkpt['numpy_rng_state'])

    if 'python_rng_state' in chkpt:
        print(f"setting Python RNG state from {args.initial_checkpoint}...")
        random.setstate(chkpt['python_rng_state'])
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Return 1 minus the normalized cross-correlation, averaged over the batch.

    Reduces over all non-batch dimensions; a small epsilon (1e-9) inside the
    square root guards against division by zero for all-zero signals.
    """
    reduce_dims = list(range(1, y_true.dim()))

    numerator = torch.sum(y_true * y_pred, dim=reduce_dims)
    denominator = torch.sqrt(
        torch.sum(y_true ** 2, dim=reduce_dims)
        * torch.sum(y_pred ** 2, dim=reduce_dims)
        + 1e-9
    )

    return torch.mean(1 - numerator / denominator)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain L2 loss, scaled by the RMS of the prediction.

    Per item: mean squared error divided by (RMS(y_pred) + 1e-6), then
    averaged over the batch. The epsilon avoids division by zero.
    """
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    pred_rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5

    return (mse / (pred_rms + 1e-6)).mean()
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Time-domain L1 loss with optional relative scaling.

    With pow=0 this is the plain mean absolute error; with pow=1 each item's
    error is divided by the mean absolute value of the prediction (+1e-9),
    making the loss scale-invariant. Averaged over the batch.
    """
    reduce_dims = list(range(1, y_true.dim()))

    abs_err = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    scale = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return torch.mean(abs_err / scale)
|
||||
|
||||
def criterion(x, y):
    """Combined regression loss between prediction x and target y.

    Weighted sum of time-domain L1, multi-resolution STFT, log-mel,
    cross-correlation and scaled L2 terms, normalized by the total weight
    w_sum (weights are module-level values read from the setup file).
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
log_interval = 10
|
||||
|
||||
|
||||
m_r = 0
|
||||
m_f = 0
|
||||
s_r = 1
|
||||
s_f = 1
|
||||
|
||||
def optimizer_to(optim, device):
    """Move all tensors held in an optimizer's state to the given device.

    Optimizer state (e.g. momentum buffers) created on one device does not
    follow the model when it is moved, so it is transferred explicitly here.
    """
    def _move_tensor(t):
        t.data = t.data.to(device)
        if t._grad is not None:
            t._grad.data = t._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _move_tensor(entry)
        elif isinstance(entry, dict):
            # per-parameter state dicts (the common case)
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _move_tensor(value)
|
||||
|
||||
# move optimizer state to the training device (momentum buffers etc. do not
# follow model.to(device) automatically)
optimizer_to(optimizer, device)
optimizer_to(optimizer_disc, device)


for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")

    model.to(device)
    disc.to(device)
    model.train()
    disc.train()

    running_disc_loss = 0
    running_adv_loss = 0
    running_feature_loss = 0
    running_reg_loss = 0

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target'].to(device)
            disc_target = batch[adv_target].to(device)

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # discriminator update (least-squares GAN objective); the
            # generator output is detached so only the discriminator
            # receives gradients here
            scores_gen = disc(output.detach())
            scores_real = disc(disc_target.unsqueeze(1))

            disc_loss = 0
            for scale in scores_gen:
                disc_loss += ((scale[-1]) ** 2).mean()
                # exponential moving averages of fake-score statistics
                m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()

            for scale in scores_real:
                disc_loss += ((1 - scale[-1]) ** 2).mean()
                # exponential moving averages of real-score statistics
                m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()

            disc_loss = 0.5 * disc_loss / len(scores_gen)
            # diagnostic only: probability that a fake score exceeds a real
            # one, assuming Gaussian score distributions
            winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )

            disc.zero_grad()
            disc_loss.backward()
            optimizer_disc.step()

            # generator update: fresh discriminator scores, this time with
            # gradients flowing into the generator
            scores_gen = disc(output)

            # calculate regression loss against the clean target
            loss_reg = criterion(output.squeeze(1), target)

            num_discs = len(scores_gen)
            loss_gen = 0
            for scale in scores_gen:
                loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs

            # feature-matching loss over all intermediate discriminator
            # layers (the last entry of each scale is the final score)
            loss_feat = 0
            for k in range(num_discs):
                num_layers = len(scores_gen[k]) - 1
                f = 4 / num_discs / num_layers
                for l in range(num_layers):
                    loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach())

            model.zero_grad()

            (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward()

            optimizer.step()

            running_adv_loss += loss_gen.detach().cpu().item()
            running_disc_loss += disc_loss.detach().cpu().item()
            running_feature_loss += lambda_feat * loss_feat.detach().cpu().item()
            running_reg_loss += lambda_reg * loss_reg.detach().cpu().item()

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}",
                                   disc_loss=f"{running_disc_loss/(i + 1):8.7f}",
                                   feat_loss=f"{running_feature_loss/(i + 1):8.7f}",
                                   reg_loss=f"{running_reg_loss/(i + 1):8.7f}",
                                   wc=f"{100*winning_chance:5.2f}%")

    # save checkpoint (RNG states included for reproducible resumption)
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['disc_state_dict'] = disc.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['torch_rng_state'] = torch.get_rng_state()
    checkpoint['numpy_rng_state'] = np.random.get_state()
    checkpoint['python_rng_state'] = random.getstate()
    checkpoint['adv_loss'] = running_adv_loss/(i + 1)
    checkpoint['disc_loss'] = running_disc_loss/(i + 1)
    checkpoint['feature_loss'] = running_feature_loss/(i + 1)
    checkpoint['reg_loss'] = running_reg_loss/(i + 1)


    if inference_test:
        # enhance the test item with the current model; report PESQ if a
        # clean reference is available
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))


    print()

print('Done')
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from models import model_dict
|
||||
from utils import endoscopy
|
||||
|
||||
# command-line interface for the testvector generation script
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint_path', type=str, help='path to folder containing checkpoints "lace_checkpoint.pth" and nolace_checkpoint.pth"')
parser.add_argument('output_folder', type=str, help='output folder for testvectors')
parser.add_argument('--debug', action='store_true', help='add debug output to output folder')
|
||||
|
||||
|
||||
def create_adaconv_testvector(prefix, adaconv, num_frames, debug=False):
    """Run an AdaConv layer on random input and dump features, input and
    output as raw float32 test vectors under the given filename prefix."""
    num_samples = num_frames * adaconv.frame_size

    features = torch.randn((1, num_frames, adaconv.feature_dim))
    x_in = torch.randn((1, adaconv.in_channels, num_samples))

    x_out = adaconv(x_in, features, debug=debug)

    def frames_first(signal, channels):
        # (channels, samples) -> (frame, channel, frame_size) layout used by
        # the C unit tests
        return signal.reshape(channels, num_frames, adaconv.frame_size).permute(1, 0, 2).detach().numpy()

    features[0].detach().numpy().tofile(prefix + '_features.f32')
    frames_first(x_in[0], adaconv.in_channels).tofile(prefix + '_x_in.f32')
    frames_first(x_out[0], adaconv.out_channels).tofile(prefix + '_x_out.f32')
|
||||
|
||||
def create_adacomb_testvector(prefix, adacomb, num_frames, debug=False):
    """Run an AdaComb layer on random input and dump features, signal,
    pitch lags and output as raw test vectors under the given prefix."""
    num_samples = num_frames * adacomb.frame_size

    features = torch.randn((1, num_frames, adacomb.feature_dim))
    x_in = torch.randn((1, 1, num_samples))
    # pitch lags must be at least kernel_size for the comb filter
    p_in = torch.randint(adacomb.kernel_size, 250, (1, num_frames))

    x_out = adacomb(x_in, features, p_in, debug=debug)

    features[0].detach().numpy().tofile(prefix + '_features.f32')
    x_in[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_in.f32')
    p_in[0].detach().numpy().astype(np.int32).tofile(prefix + '_p_in.s32')
    x_out[0].permute(1, 0).detach().numpy().tofile(prefix + '_x_out.f32')
|
||||
|
||||
def create_adashape_testvector(prefix, adashape, num_frames):
    """Run an AdaShape layer on random input and dump features, input and
    output as raw float32 test vectors under the given filename prefix."""
    features = torch.randn((1, num_frames, adashape.feature_dim))
    x_in = torch.randn((1, 1, num_frames * adashape.frame_size))

    x_out = adashape(x_in, features)

    dumps = (
        (features[0].detach().numpy(), '_features.f32'),
        (x_in.flatten().detach().numpy(), '_x_in.f32'),
        (x_out.flatten().detach().numpy(), '_x_out.f32'),
    )
    for array, suffix in dumps:
        array.tofile(prefix + suffix)
|
||||
|
||||
def create_feature_net_testvector(prefix, model, num_frames):
    """Run the feature net of a (No)LACE model on random input and dump all
    inputs and intermediate signals as test vectors under the given prefix."""
    num_subframes = 4 * num_frames

    input_features = torch.randn((1, num_subframes, model.num_features))
    periods = torch.randint(32, 300, (1, num_subframes))
    nb_lo, nb_hi = model.numbits_range
    numbits = nb_lo + torch.rand((1, num_frames, 2)) * (nb_hi - nb_lo)

    # embed pitch and bitrate information; the per-frame numbits embedding is
    # replicated to subframe rate before concatenation
    pembed = model.pitch_embedding(periods)
    nembed = torch.repeat_interleave(model.numbits_embedding(numbits).flatten(2), 4, dim=1)
    full_features = torch.cat((input_features, pembed, nembed), dim=-1)

    cf = model.feature_net(full_features)

    input_features.float().numpy().tofile(prefix + "_in_features.f32")
    periods.numpy().astype(np.int32).tofile(prefix + "_periods.s32")
    numbits.float().numpy().tofile(prefix + "_numbits.f32")
    full_features.detach().numpy().tofile(prefix + "_full_features.f32")
    cf.detach().numpy().tofile(prefix + "_out_features.f32")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    os.makedirs(args.output_folder, exist_ok=True)

    # load pre-trained LACE and NoLACE checkpoints (CPU is sufficient here)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoint files
    lace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "lace_checkpoint.pth"), map_location='cpu')
    nolace_checkpoint = torch.load(os.path.join(args.checkpoint_path, "nolace_checkpoint.pth"), map_location='cpu')

    # rebuild the models from the stored constructor kwargs and load weights
    lace = model_dict['lace'](**lace_checkpoint['setup']['model']['kwargs'])
    nolace = model_dict['nolace'](**nolace_checkpoint['setup']['model']['kwargs'])

    lace.load_state_dict(lace_checkpoint['state_dict'])
    nolace.load_state_dict(nolace_checkpoint['state_dict'])

    if args.debug:
        endoscopy.init(args.output_folder)

    # lace af1, 1 input channel, 1 output channel
    create_adaconv_testvector(os.path.join(args.output_folder, "lace_af1"), lace.af1, 5, debug=args.debug)

    # nolace af1, 1 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af1"), nolace.af1, 5, debug=args.debug)

    # nolace af4, 2 input channel, 1 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af4"), nolace.af4, 5, debug=args.debug)

    # nolace af2, 2 input channel, 2 output channels
    create_adaconv_testvector(os.path.join(args.output_folder, "nolace_af2"), nolace.af2, 5, debug=args.debug)

    # lace cf1
    create_adacomb_testvector(os.path.join(args.output_folder, "lace_cf1"), lace.cf1, 5, debug=args.debug)

    # nolace tdshape1
    create_adashape_testvector(os.path.join(args.output_folder, "nolace_tdshape1"), nolace.tdshape1, 5)

    # lace feature net
    create_feature_net_testvector(os.path.join(args.output_folder, 'lace'), lace, 5)

    if args.debug:
        endoscopy.close()
|
||||
@@ -0,0 +1,2 @@
|
||||
from .silk_enhancement_set import SilkEnhancementSet
|
||||
from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
|
||||
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" Dataset for LPCNet training """
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# scaling between 16-bit linear PCM range and 8-bit mu-law range
scale = 255.0/32768.0
scale_1 = 32768.0/255.0

def ulaw2lin(u):
    """Map 8-bit mu-law codes (0..255, 128 = silence) back to linear
    16-bit-range amplitudes."""
    centered = u - 128
    sign = np.sign(centered)
    magnitude = np.abs(centered)
    return sign * scale_1 * (np.exp(magnitude / 128. * np.log(256)) - 1)
|
||||
|
||||
|
||||
def lin2ulaw(x):
    """Map linear 16-bit-range amplitudes to 8-bit mu-law codes (0..255).

    Uses the module-level `scale` constant; the inverse of ulaw2lin up to
    rounding.
    """
    sign = np.sign(x)
    magnitude = np.abs(x)
    compressed = (sign * (128 * np.log(1 + scale * magnitude) / np.log(256)))
    return np.clip(128 + np.round(compressed), 0, 255)
|
||||
|
||||
|
||||
def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise short-term LPC analysis to a signal.

    `signal` must carry lpc_order extra history samples at its start
    (length num_frames * frame_length + lpc_order). Returns the per-sample
    LPC prediction and the prediction error (signal minus prediction), both
    aligned with signal[lpc_order:].
    """
    num_frames, lpc_order = lpcs.shape

    frame_predictions = []
    for frame in range(num_frames):
        segment = signal[frame * frame_length : (frame + 1) * frame_length + lpc_order - 1]
        # 'valid' convolution yields exactly frame_length prediction samples
        frame_predictions.append(- np.convolve(segment, lpcs[frame], mode='valid'))

    prediction = np.concatenate(frame_predictions)
    error = signal[lpc_order :] - prediction

    return prediction, error
|
||||
|
||||
class LPCNetVocodingDataset(Dataset):
    """Dataset of acoustic features and aligned signals for LPCNet vocoder
    training.

    Reads memory-mapped feature and signal files described by an info.yml in
    the dataset folder and serves fixed-length training samples. Two on-disk
    layouts (version 1 and 2) are supported.
    """
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):
        # NOTE(review): `features` uses a mutable default list; it is only
        # read, never mutated, so this is safe but fragile.

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the sample-assembly implementation
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        # 2 extra frames of safety margin at the start
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # load feature file (memory-mapped; one row per frame after reshape)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # load signals (memory-mapped; one row per sample after reshape)
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length


    def __getitem__(self, index):
        # dispatch to the version-specific implementation chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Assemble one version-2 training sample."""
        sample = dict()

        # extract features (with history/lookahead context)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods from their float encoding back to integer lags
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # lpc coefficients with one frame lookahead
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting (bandwidth expansion by lpc_gamma)
            lpc_order = lpc_stop - lpc_start
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            # NOTE(review): clean_signal is computed but not used below --
            # confirm whether it was meant to feed the error calculation
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the one-frame warm-up)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']


        # concatenate features (periods are returned separately)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        target = torch.FloatTensor(sample[self.target]) / 2**15
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'target' : target}

    def getitem_v1(self, index):
        """Assemble one version-1 training sample."""
        sample = dict()

        # extract features (with history/lookahead context)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods from their float encoding back to integer lags
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        # NOTE(review): the loop variable shadows the method argument
        # `index` (harmless, since `index` is not used afterwards)
        for signal_name, index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, index]

        # concatenate features (periods are returned separately)
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        # NOTE(review): self.input_signals is never assigned in __init__, so
        # this path raises AttributeError -- confirm the intended signal list
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        # number of complete fixed-length samples in the dataset
        return self.dataset_length
|
||||
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
from utils.silk_features import silk_feature_factory
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
|
||||
|
||||
class SilkEnhancementSet(Dataset):
    """Dataset serving SILK features, LPCNet features and bitrate side
    information for training SILK enhancement models on decoded speech.

    This variant provides no clean target; it only reads the coded signal
    plus the per-frame SILK parameters dumped by the patched opus_demo.
    """
    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=9,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_offset=False,
                 add_double_lag_acorr=False
                 ):

        # samples are built from whole groups of 4 SILK frames
        assert frames_per_sample % 4 == 0

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr
        # bug fix: __getitem__ subtracts self.skip from the signal indices,
        # but no skip was ever set in this variant (AttributeError). This
        # variant applies no pre-skip, so keep an explicit zero.
        self.skip = 0

        # per-frame SILK parameters dumped by the patched opus_demo
        self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
        self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
        self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
        self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
        self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
        self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
        self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32)
        # bug fix: np.from_file does not exist; the correct API is np.fromfile
        self.lpcnet_features = np.fromfile(os.path.join(path, 'features_lpcnet.f32'), dtype=np.float32).reshape(-1, 36)

        self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_offset,
                                                    add_double_lag_acorr)

        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some frames to have enough signal history
        self.skip_frames = 4 * ((self.history_len + 319) // 320 + 2)

        # bug fix: this dataset only loads the coded signal; derive the
        # frame count from it (the original referenced a clean_signal
        # attribute that is never created in this variant)
        num_frames = self.coded_signal.shape[0] // 80 - self.skip_frames

        self.len = num_frames // frames_per_sample

    def __len__(self):
        # number of complete fixed-length samples in the dataset
        return self.len

    def __getitem__(self, index):
        """Return SILK features, pitch periods, numbits side info and
        LPCNet features for training sample `index`."""
        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        # extra history preceding the sample, needed for autocorrelation
        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop],
            self.offsets[frame_start : frame_stop]
        )

        # NOTE(review): lpcnet features appear to run at half the SILK frame
        # rate (indices halved) -- confirm against the feature dump format
        lpcnet_features = self.lpcnet_features[frame_start // 2 : frame_stop // 2, :20]

        # numbits side info is stored once per 4-frame group; repeat to frame rate
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'silk_features' : features,
            'periods' : periods.astype(np.int64),
            'numbits' : numbits.astype(np.float32),
            'lpcnet_features' : lpcnet_features
        }
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
|
||||
from utils.silk_features import silk_feature_factory
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
|
||||
|
||||
class SilkEnhancementSet(Dataset):
|
||||
def __init__(self,
|
||||
path,
|
||||
frames_per_sample=100,
|
||||
no_pitch_value=256,
|
||||
preemph=0.85,
|
||||
skip=91,
|
||||
acorr_radius=2,
|
||||
pitch_hangover=8,
|
||||
num_bands_clean_spec=64,
|
||||
num_bands_noisy_spec=18,
|
||||
noisy_spec_scale='opus',
|
||||
noisy_apply_dct=True,
|
||||
add_double_lag_acorr=False,
|
||||
):
|
||||
|
||||
assert frames_per_sample % 4 == 0
|
||||
|
||||
self.frame_size = 80
|
||||
self.frames_per_sample = frames_per_sample
|
||||
self.no_pitch_value = no_pitch_value
|
||||
self.preemph = preemph
|
||||
self.skip = skip
|
||||
self.acorr_radius = acorr_radius
|
||||
self.pitch_hangover = pitch_hangover
|
||||
self.num_bands_clean_spec = num_bands_clean_spec
|
||||
self.num_bands_noisy_spec = num_bands_noisy_spec
|
||||
self.noisy_spec_scale = noisy_spec_scale
|
||||
self.add_double_lag_acorr = add_double_lag_acorr
|
||||
|
||||
self.lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
|
||||
self.ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
|
||||
self.periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
|
||||
self.gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
|
||||
self.num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32)
|
||||
self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32)
|
||||
|
||||
self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16)
|
||||
self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16)
|
||||
self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16)
|
||||
|
||||
self.create_features = silk_feature_factory(no_pitch_value,
|
||||
acorr_radius,
|
||||
pitch_hangover,
|
||||
num_bands_clean_spec,
|
||||
num_bands_noisy_spec,
|
||||
noisy_spec_scale,
|
||||
noisy_apply_dct,
|
||||
add_double_lag_acorr)
|
||||
|
||||
self.history_len = 700 if add_double_lag_acorr else 350
|
||||
# discard some frames to have enough signal history
|
||||
self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2)
|
||||
|
||||
num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames
|
||||
|
||||
self.len = num_frames // frames_per_sample
|
||||
|
||||
def __len__(self):
    # Number of samples in the dataset; computed once in __init__ as
    # num_frames // frames_per_sample after discarding warm-up frames.
    return self.len
|
||||
def __getitem__(self, index):
    """Assemble one training sample of self.frames_per_sample frames.

    Returns a dict with SILK features, pitch periods, clean targets
    (plain and highpass-filtered), the coded input signal and per-frame
    bitrate information.
    """

    # frame range for this sample; skip_frames (set in __init__) guarantees
    # enough signal history before signal_start
    frame_start = self.frames_per_sample * index + self.skip_frames
    frame_stop = frame_start + self.frames_per_sample

    # sample range in the concatenated pcm streams; self.skip shifts the
    # window, presumably to compensate codec delay -- TODO confirm
    signal_start = frame_start * self.frame_size - self.skip
    signal_stop = frame_stop * self.frame_size - self.skip

    # convert 16-bit pcm to float in [-1, 1)
    clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15
    clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15
    coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

    # history needed by the feature factory (e.g. for autocorrelation)
    coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

    features, periods = self.create_features(
        coded_signal,
        coded_signal_history,
        self.lpcs[frame_start : frame_stop],
        self.gains[frame_start : frame_stop],
        self.ltps[frame_start : frame_stop],
        self.periods[frame_start : frame_stop]
    )

    # optional pre-emphasis applied consistently to input and targets
    if self.preemph > 0:
        clean_signal[1:] -= self.preemph * clean_signal[: -1]
        clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1]
        coded_signal[1:] -= self.preemph * coded_signal[: -1]

    # bitrate info is stored once per 4 frames; repeat to per-frame resolution
    num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
    num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

    numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

    return {
        'features' : features,
        'periods' : periods.astype(np.int64),
        'target_orig' : clean_signal.astype(np.float32),
        'target' : clean_signal_hp.astype(np.float32),
        'signals' : coded_signal.reshape(-1, 1).astype(np.float32),
        'numbits' : numbits.astype(np.float32)
    }
103
managed_components/78__esp-opus/dnn/torch/osce/engine/engine.py
Normal file
103
managed_components/78__esp-opus/dnn/torch/osce/engine/engine.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the loss averaged over all batches.

    Args:
        model: enhancement model; called as model(signals, features, periods, numbits).
        criterion: loss callable invoked as criterion(target, prediction).
        optimizer: optimizer stepped once per batch.
        dataloader: yields dict batches containing at least 'signals',
            'features', 'periods', 'numbits' and 'target' tensors.
        device: torch device to train on.
        scheduler: learning-rate scheduler stepped once per batch (not per epoch).
        log_interval: number of batches between progress-bar updates.

    Returns:
        float: average loss over the epoch.
    """

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0


    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()


            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output; signals are permuted from (B, T, C)
            # to the (B, C, T) layout the model expects
            output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

            # calculate loss; some models return one output per stage, in
            # which case the stage losses are averaged
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # sparsification (only applied if the model defines a sparsifier)
            if hasattr(model, 'sparsifier'):
                model.sparsifier()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model on a dataloader and return the average loss.

    Runs under torch.no_grad(); expects the same batch layout as
    train_one_epoch (dicts with 'signals', 'features', 'periods',
    'numbits' and 'target').
    """

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0


    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):

                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
@@ -0,0 +1,101 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
|
||||
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Run one training epoch and return the loss averaged over all batches.

    Variant for models called as model(features, periods) -- the batch
    dicts must contain 'features', 'periods' and 'target' tensors.

    Args:
        model: model under training.
        criterion: loss callable invoked as criterion(target, prediction).
        optimizer: optimizer stepped once per batch.
        dataloader: yields dict batches.
        device: torch device to train on.
        scheduler: learning-rate scheduler stepped once per batch.
        log_interval: number of batches between progress-bar updates.

    Returns:
        float: average loss over the epoch.
    """

    model.to(device)
    model.train()

    running_loss = 0
    previous_running_loss = 0


    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()


            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'])

            # calculate loss; per-stage outputs (lists) are averaged
            if isinstance(output, list):
                loss = torch.zeros(1, device=device)
                for y in output:
                    loss = loss + criterion(target, y.squeeze(1))
                loss = loss / len(output)
            else:
                loss = criterion(target, output.squeeze(1))

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # update running loss
            running_loss += float(loss.cpu())

            # update status bar
            if i % log_interval == 0:
                tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Evaluate the model on a dataloader and return the average loss.

    Variant for models called as model(features, periods). Runs under
    torch.no_grad().
    """

    model.to(device)
    model.eval()

    running_loss = 0
    previous_running_loss = 0


    with torch.no_grad():
        with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

            for i, batch in enumerate(tepoch):


                # push batch to device
                for key in batch:
                    batch[key] = batch[key].to(device)

                target = batch['target']

                # calculate model output
                output = model(batch['features'], batch['periods'])

                # calculate loss
                loss = criterion(target, output.squeeze(1))

                # update running loss
                running_loss += float(loss.cpu())

                # update status bar
                if i % log_interval == 0:
                    tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
                    previous_running_loss = running_loss


    running_loss /= len(dataloader)

    return running_loss
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import hashlib
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
|
||||
|
||||
import torch
|
||||
import wexchange.torch
|
||||
from wexchange.torch import dump_torch_weights
|
||||
from models import model_dict
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.misc import remove_all_weight_norm
|
||||
from wexchange.torch import dump_torch_weights
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
|
||||
parser.add_argument('output_dir', type=str, help='output folder')
|
||||
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
|
||||
|
||||
sparse_default=False
|
||||
schedules = {
|
||||
'nolace': [
|
||||
('pitch_embedding', dict()),
|
||||
('feature_net.conv1', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
|
||||
('cf1', dict(quantize=True, scale=None)),
|
||||
('cf2', dict(quantize=True, scale=None)),
|
||||
('af1', dict(quantize=True, scale=None)),
|
||||
('tdshape1', dict(quantize=True, scale=None)),
|
||||
('tdshape2', dict(quantize=True, scale=None)),
|
||||
('tdshape3', dict(quantize=True, scale=None)),
|
||||
('af2', dict(quantize=True, scale=None)),
|
||||
('af3', dict(quantize=True, scale=None)),
|
||||
('af4', dict(quantize=True, scale=None)),
|
||||
('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
|
||||
],
|
||||
'lace' : [
|
||||
('pitch_embedding', dict()),
|
||||
('feature_net.conv1', dict()),
|
||||
('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
|
||||
('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
|
||||
('cf1', dict(quantize=True, scale=None)),
|
||||
('cf2', dict(quantize=True, scale=None)),
|
||||
('af1', dict(quantize=True, scale=None))
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# auxiliary functions
|
||||
def sha1(filename):
    """Return the SHA-1 hex digest of the file at *filename*.

    Reads the file in 64 KiB chunks so large checkpoints can be hashed
    without loading them into memory at once.
    """
    BUF_SIZE = 65536
    # renamed from 'sha1' so the local no longer shadows this function
    digest = hashlib.sha1()

    with open(filename, 'rb') as f:
        while True:
            data = f.read(BUF_SIZE)
            if not data:
                break
            digest.update(data)

    return digest.hexdigest()
def osce_dump_generic(writer, name, module):
    # Recursively dump model weights: modules of a known leaf type are
    # written directly via dump_torch_weights, anything else is traversed
    # child by child with an underscore-joined name prefix.
    if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv1d) \
        or isinstance(module, torch.nn.ConvTranspose1d) or isinstance(module, torch.nn.Embedding) \
            or isinstance(module, LimitedAdaptiveConv1d) or isinstance(module, LimitedAdaptiveComb1d) \
                or isinstance(module, TDShaper) or isinstance(module, torch.nn.GRU):
        dump_torch_weights(writer, module, name=name, verbose=True)
    else:
        for child_name, child in module.named_children():
            # 'feature_net' is shortened to 'fnet' to keep C identifiers compact
            osce_dump_generic(writer, (name + "_" + child_name).replace("feature_net", "fnet"), child)
def export_name(name):
    """Map a PyTorch submodule path to a C-friendly identifier.

    Dots become underscores and the long 'feature_net' prefix is
    shortened to 'fnet'.
    """
    return name.replace('.', '_').replace('feature_net', 'fnet')
def osce_scheduled_dump(writer, prefix, model, schedule):
    """Dump model weights following an explicit (submodule name, kwargs) schedule."""
    # make sure submodule names are separated from the prefix
    prefix = prefix if prefix.endswith('_') else prefix + '_'

    for name, kwargs in schedule:
        dump_torch_weights(writer, model.get_submodule(name), prefix + export_name(name), **kwargs, verbose=True)
if __name__ == "__main__":
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # dump message; records provenance (checkpoint name + sha1) in the C output
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # create model and load weights; the model class and its constructor
    # arguments are taken from the checkpoint's setup metadata
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    remove_all_weight_norm(model, verbose=True)

    # CWriter generating <model_name>_data.{c,h}
    model_name = checkpoint['setup']['model']['name']
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)

    # Add custom includes and global parameters
    cwriter.header.write(f'''
#define {model_name.upper()}_PREEMPH {model.preemph}f
#define {model_name.upper()}_FRAME_SIZE {model.FRAME_SIZE}
#define {model_name.upper()}_OVERLAP_SIZE 40
#define {model_name.upper()}_NUM_FEATURES {model.num_features}
#define {model_name.upper()}_PITCH_MAX {model.pitch_max}
#define {model_name.upper()}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
#define {model_name.upper()}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
#define {model_name.upper()}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
#define {model_name.upper()}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
#define {model_name.upper()}_COND_DIM {model.cond_dim}
#define {model_name.upper()}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')

    # one scale-factor define per numbits embedding entry
    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")

    # dump layers: a quantization schedule is used when available and
    # requested, otherwise fall back to generic recursive dumping
    if model_name in schedules and args.quantize:
        osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
    else:
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
    """Create a torch window tensor of the given length by window name.

    Args:
        win_name (str): one of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window', 'kaiser_window'.
        win_length (int): window length in samples.
        *args, **kwargs: forwarded to the underlying torch window factory.

    Returns:
        Tensor: the requested window.

    Raises:
        ValueError: if win_name is not a supported window type.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window'  : torch.hamming_window,
        'hann_window'     : torch.hann_window,
        'kaiser_window'   : torch.kaiser_window
    }

    # was a bare ValueError(); include the offending name for debuggability
    if win_name not in window_dict:
        raise ValueError(f"unknown window type: {win_name}")

    return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
    """Perform an STFT and return the magnitude spectrogram.

    Args:
        x (Tensor): input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): hop size.
        win_length (int): window length.
        window (str): window-function name understood by get_window.

    Returns:
        Tensor: magnitude spectrogram, lower-bounded by 1e-7 so that
        logarithms taken downstream stay finite.
    """
    analysis_window = get_window(window, win_length).to(x.device)
    spectrum = torch.stft(x, fft_size, hop_size, win_length, analysis_window, return_complex=True)
    magnitude = torch.abs(spectrum)

    return torch.clamp(magnitude, min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
    """Spectral convergence: Frobenius norm of the magnitude error,
    normalized by the norm of the prediction and averaged over the batch.

    NOTE: normalization by Y_pred (not Y_true) matches the original
    implementation.
    """
    reduce_dims = list(range(1, len(Y_pred.shape)))
    error_norm = torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=reduce_dims)
    ref_norm = torch.norm(Y_pred, p="fro", dim=reduce_dims)

    return torch.mean(error_norm / (ref_norm + 1e-6))
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute error between log-magnitude spectra.

    A small epsilon keeps the logarithms finite for zero bins.
    """
    eps = 1e-15
    log_error = torch.log(torch.abs(Y_true) + eps) - torch.log(torch.abs(Y_pred) + eps)

    return torch.mean(torch.abs(log_error))
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the normalized cross-correlation of the magnitude
    spectra, averaged over the batch.

    The loss is 0 when the two magnitude spectra are proportional.
    """
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, len(mag_pred.shape)))

    numerator = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    denominator = torch.sqrt(torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9)

    return 1 - (numerator / denominator).mean()
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel spectrogram loss.

    Computes mel spectrograms of target and prediction at several FFT
    resolutions and averages the mean absolute log-magnitude error.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        # plain attributes assigned before super().__init__(); safe because
        # none of them are Parameters or Modules
        self.fft_sizes = fft_sizes   # FFT sizes, one mel transform per size
        self.overlap = overlap       # fractional overlap; hop = fft_size * (1 - overlap)
        self.fs = fs                 # sample rate in Hz
        self.n_mels = n_mels         # mel bands used at the larger FFT sizes

        super().__init__()

        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            if fft_size < 128:
                # fewer mel bands at very low resolution
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register the transforms as submodules so .to(device) moves them
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the log-mel loss averaged over all resolutions."""

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        loss = loss / len(self.mel_specs)

        return loss
def create_weight_matrix(num_bins, bins_per_band=10):
    """Create a (num_bins, num_bins) band-averaging matrix.

    Row i averages roughly bins_per_band bins centered on bin i; mass
    that would fall outside the valid bin range is folded back onto the
    edge bins, so every row sums to one.
    """
    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    left = bins_per_band // 2
    right = bins_per_band - left

    for band in range(num_bins):
        lo = max(band - left, 0)
        hi = min(band + right, num_bins)
        weights[band, lo: hi] += 1

        # fold truncated mass back onto the edges
        if band < left:
            weights[band, :left - band] += 1

        if band > num_bins - right:
            weights[band, num_bins - right - band:] += 1

    return weights / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
    # calculate sfm based weights: torch.exp(mean(log)) / mean is the
    # spectral flatness measure over a local band -- near 1 for noise-like
    # bands, near 0 for tonal bands
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)

    # w is a (bins, bins) band-averaging matrix (see create_weight_matrix);
    # assumes Y_* have shape (batch, bins, frames) -- TODO confirm with callers
    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)

    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)

    # emphasize tonal regions: small sfm gives a weight close to 1
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    # weighted L1 magnitude error, normalized by the weighted target magnitude
    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )

    return loss
def gen_filterbank(N, Fs=16000):
    """Build an (N, N + 1) ERB-spaced smoothing filterbank as a torch tensor.

    ERB bandwidth formula from B.C.J. Moore, An Introduction to the
    Psychology of Hearing, 5th Ed., page 73. Each output row is
    normalized to sum to one.
    """
    freq_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    freq_out = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]

    erb_bw = 24.7 + .108 * freq_in
    delta = np.abs(freq_in - freq_out) / erb_bw

    # piecewise response in dB: parabolic near the center, linear tails
    near_center = (delta < .5).astype('float32')
    response_db = -12 * near_center * delta ** 2 + (1 - near_center) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)

    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-magnitude error after smoothing both spectra
    with the given (perceptual) filterbank."""
    smoothed_true = torch.matmul(filterbank, torch.abs(Y_true))
    smoothed_pred = torch.matmul(filterbank, torch.abs(Y_pred))

    log_error = torch.log(smoothed_true + 1e-9) - torch.log(smoothed_pred + 1e-9)

    return torch.abs(log_error).mean()
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss combining several spectral distances.

    Each component with a non-zero weight (log-magnitude, spectral
    convergence, SFM-weighted spectral convergence, smoothed
    log-magnitude, spectral cross-correlation) is computed at every FFT
    size in fft_sizes; the weighted sum is normalized by the number of
    resolutions.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=1,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=0,
                 sxcorr_weight=0):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap          # fractional window overlap; hop = fft_size * (1 - overlap)
        self.window = window            # window name passed to stft()
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss; stored as
        # non-trainable Parameters so they move with the module's device
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width of roughly 1 ms worth of bins, capped at 11, rounded even
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )


    # NOTE(review): overriding __call__ (rather than forward) bypasses
    # nn.Module hooks; kept as-is to preserve behavior.
    def __call__(self, y_true, y_pred):
        """Return the total weighted loss for signals of shape (B, T)."""

        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            # components are only evaluated when their weight is non-zero
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)


        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import scipy.signal
|
||||
|
||||
|
||||
from utils.layers.fir import FIR
|
||||
|
||||
class TDLowpass(torch.nn.Module):
    """Time-domain lowpass loss: mean |lowpassed error| ** power.

    The error signal is filtered with an FIR lowpass before the norm is
    taken, so only low-frequency deviations contribute to the loss.
    """

    def __init__(self, numtaps, cutoff, power=2):
        """
        Args:
            numtaps (int): number of FIR taps for the lowpass filter.
            cutoff (float): normalized cutoff frequency (Nyquist == 1).
            power (int): exponent applied to the filtered error.
        """
        super().__init__()

        self.b = scipy.signal.firwin(numtaps, cutoff)
        self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
        self.power = power

    def forward(self, y_true, y_pred):
        """Return (loss, lowpassed difference) for 3-dimensional signals."""
        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

        residual = y_true - y_pred
        residual_lp = torch.nn.functional.conv1d(residual, self.weight)

        loss = torch.mean(torch.abs(residual_lp ** self.power))

        return loss, residual_lp

    def get_freqz(self):
        """Return (frequencies, complex response) of the lowpass filter."""
        freq, response = scipy.signal.freqz(self.b)

        return freq, response
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
from utils.templates import setup_dict
|
||||
|
||||
parser = argparse.ArgumentParser()

parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

# NOTE(review): argument parsing and setup selection run at import time;
# only the file write below is guarded by __main__.
args = parser.parse_args()

# adversarial setups are stored under '<model>_adv' in the template dict
key = args.model + "_adv" if args.adversarial else args.model

try:
    setup = setup_dict[key]
except KeyError:
    print("setup not found, adversarial training possibly not specified for model")
    sys.exit(1)

# update dataset if given
if type(args.path2dataset) != type(None):
    setup['dataset'] = args.path2dataset

# ensure a .yml extension on the output file name
name = args.name
if not name.endswith('.yml'):
    name += '.yml'

if __name__ == '__main__':
    with open(name, 'w') as f:
        f.write(yaml.dump(setup))
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from .lace import LACE
|
||||
from .no_lace import NoLACE
|
||||
from .lavoce import LaVoce
|
||||
from .lavoce_400 import LaVoce400
|
||||
from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
|
||||
|
||||
# Registry mapping the model name used in setup/config files to its class.
# Training scripts look up the 'model' entry of a setup here to instantiate
# the network.
model_dict = {
    'lace': LACE,
    'nolace': NoLACE,
    'lavoce': LaVoce,
    'lavoce400': LaVoce400,
    'fdmresdisc': FDMResDisc,
}
|
||||
@@ -0,0 +1,974 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
import torchaudio
|
||||
|
||||
from utils.spec import gen_filterbank
|
||||
|
||||
# auxiliary functions
|
||||
|
||||
def remove_all_weight_norms(module):
    """Strip weight normalization from every submodule of *module*.

    A module wrapped by ``weight_norm`` carries a ``weight_v`` attribute,
    which is the detection criterion used here; matching submodules are
    unwrapped in place.
    """
    normed_submodules = (sub for sub in module.modules() if hasattr(sub, 'weight_v'))
    for sub in normed_submodules:
        nn.utils.remove_weight_norm(sub)
|
||||
|
||||
|
||||
def create_smoothing_kernel(h, w, gamma=1.5):
    """Return an (h, w) Gaussian-like window normalized to sum to 1.

    The window is centered on the kernel; *gamma* scales the width
    (larger gamma => flatter window).
    """
    center_y = h / 2 - 0.5
    center_x = w / 2 - 0.5

    # standard deviations proportional to the distance from center to edge
    sigma_y = gamma * center_y
    sigma_x = gamma * center_x

    dist_y = ((torch.arange(h) - center_y) / sigma_y) ** 2
    dist_x = ((torch.arange(w) - center_x) / sigma_x) ** 2

    window = torch.exp(-(dist_y.view(-1, 1) + dist_x.view(1, -1)))
    return window / window.sum()
|
||||
|
||||
|
||||
def create_kernel(h, w, sh, sw):
    """Build a (1, 1, h, w) kernel by smoothing an (sh, sw) box window.

    The box of ones forms a disjoint partition of unity at stride
    (sh, sw); convolving it with a smoothing window eta yields a smooth
    overlapping kernel of the requested (h, w) support.
    """
    box = torch.ones((sh, sw))

    # support of the smoothing window eta
    eta_h = h - sh + 1
    eta_w = w - sw + 1
    assert eta_h > 0 and eta_w > 0
    eta = create_smoothing_kernel(eta_h, eta_w).view(1, 1, eta_h, eta_w)

    padded_box = F.pad(box, [eta_w - 1, eta_w - 1, eta_h - 1, eta_h - 1])
    return F.conv2d(padded_box.unsqueeze(0).unsqueeze(0), eta)
|
||||
|
||||
# positional embeddings
|
||||
class FrequencyPositionalEmbedding(nn.Module):
    """Append two channels that encode the frequency-bin position.

    For input of shape (B, C, F, T), bin index f is mapped to the phase
    2*pi*f/F and sin/cos of that phase are broadcast along batch and time,
    giving an output of shape (B, C + 2, F, T).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        num_bins = x.size(2)
        phase = torch.arange(0, num_bins, dtype=x.dtype, device=x.device) * 2 * torch.pi / num_bins

        sin_col = torch.sin(phase).reshape(1, 1, -1, 1)
        cos_col = torch.cos(phase).reshape(1, 1, -1, 1)

        # broadcasting template with the input's batch/spatial dims
        template = torch.zeros_like(x[:, 0:1, :, :])

        return torch.cat((x, template + sin_col, template + cos_col), dim=1)
|
||||
|
||||
|
||||
class PositionalEmbedding2D(nn.Module):
    """Append transformer-style 2D positional encodings as extra channels.

    For input of shape (B, C, N, M) the output has C + 4*d channels: d
    sine and d cosine channels for the height position and likewise for
    the width position, with geometrically spaced frequencies (base 10000),
    broadcast over batch and the other spatial axis.
    """

    def __init__(self, d=5):
        """d: number of frequencies per (axis, sin/cos) pair; adds 4*d channels."""

        super().__init__()

        self.d = d

    def forward(self, x):

        N = x.size(2)
        M = x.size(3)

        h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
        w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
        # frequency scaling factors 10000^(-2k/d), one per embedding channel
        coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)

        h_sin = torch.sin(coeffs * h_args)
        # bugfix: the cosine components were previously computed with torch.sin,
        # which duplicated the sine channels and dropped the cosine phase
        # information entirely.  NOTE(review): checkpoints trained with the old
        # code would see different embeddings after this fix — confirm before
        # loading legacy weights.
        h_cos = torch.cos(coeffs * h_args)
        w_sin = torch.sin(coeffs * w_args)
        w_cos = torch.cos(coeffs * w_args)

        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)

        return y
|
||||
|
||||
|
||||
# spectral discriminator base class
|
||||
class SpecDiscriminatorBase(nn.Module):
    """Base class for spectrogram-domain discriminators.

    Subclasses supply a stack of 2D conv layers; this class provides the
    STFT front end (with optional shaped noise floor and a frequency ROI),
    receptive-field analysis of the layer stack, and several introspection
    tools (gradient maps, relevance maps, layer-wise relevance propagation).
    """

    # upper bound on the probe-input size used when measuring the receptive field
    RECEPTIVE_FIELD_MAX_WIDTH=10000
    def __init__(self,
                 layers,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7000],
                 noise_gain=1e-3,
                 fmap_start_index=0
                 ):
        """
        Args:
            layers: list of nn.Modules applied in sequence to the spectrogram
            resolution: (n_fft, hop_length, win_length) STFT parameters
            fs: sampling rate in Hz
            freq_roi: [f_min, f_max] frequency region of interest in Hz
            noise_gain: gain of the spectral-envelope-shaped noise floor
            fmap_start_index: first layer output returned as a feature map
        """
        super().__init__()


        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index

        if fmap_start_index >= len(layers):
            raise ValueError(f'fmap_start_index is larger than number of layers')

        # filter bank for noise shaping
        n_fft = resolution[0]

        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )

        # roi bins: convert the Hz ROI into STFT bin indices (0.01 guards
        # against float rounding at exact bin boundaries)
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1)

        self.init_weights()

        # determine receptive field size, offsets and strides by probing the
        # stack with zero inputs of growing size hw until two neighbouring
        # output positions report receptive fields of equal size
        hw = 1000
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]

            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]

            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)

            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1

            if h0 != h1 or w0 != w1:
                # receptive field clipped by the probe boundary: try larger input
                hw = 2 * hw
            else:

                # strides: shift of the receptive field between neighbouring outputs
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]

                # NOTE(review): this `continue` re-runs the loop with hw
                # unchanged, so a zero stride would loop forever — confirm
                # strides cannot be zero for the layer stacks used here.
                if sh == 0 or sw == 0: continue

                # offsets
                oh = hs0[0] - sh * pos0[0]
                ow = ws0[0] - sw * pos0[1]

                # overlap factor
                overlap = w0 / sw + h0 / sh

                #print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}")
                self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}

                break

            # NOTE(review): this only warns and does not break, so probing
            # continues past the limit — confirm intended.
            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")

        # create transposed convolutional kernel
        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)

    def run_layer_stack(self, spec):
        """Apply all layers to spectrogram *spec*; returns every layer output."""

        output = []

        x = spec.unsqueeze(1)

        for layer in self.layers:
            x = layer(x)
            output.append(x)

        return output

    def forward(self, x):
        """ returns array with feature maps and final score at index -1 """

        output = []

        x = self.spectrogram(x)

        output = self.run_layer_stack(x)

        return output[self.fmap_start_index:]

    def receptive_field(self, output_pos):
        """Return ((h_min, h_max), (w_min, w_max)) input bins influencing
        *output_pos*, based on the stride/offset/size measured in __init__;
        (None, None) if those parameters are unavailable."""

        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            # clamp the height range to the frequency ROI
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)

            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w

            return (h_min, h_max), (w_min, w_max)

        else:
            return None, None


    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field probabilistically via autograd (slow) """

        x = torch.randn((1,) + input_dims, requires_grad=True)

        # run input through layers
        y = self.run_layer_stack(x)[-1]
        b, c, h, w = y.shape

        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")

        # back-propagate from a single output position; the nonzero input
        # gradients delimit its receptive field
        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1

        (mask * y).sum().backward()

        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)

        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()

        return [h_min, h_max], [w_min, w_max]



    def init_weights(self):
        """Orthogonally initialize conv1d/linear/embedding weights.

        NOTE(review): nn.Conv2d is not in the isinstance list although the
        layer stacks here are built from Conv2d — confirm whether that is
        intentional (Conv2d keeps its default init) or an omission.
        Also note the loop variable shadows the module alias `m` (math)
        within this method.
        """

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)


    def spectrogram(self, x):
        """STFT magnitude in dB with a shaped noise floor, cropped to the ROI."""
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
            window=window, return_complex=True) #[B, F, T]
        x = torch.abs(x)

        # noise floor following spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
        x = x + noise

        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]

        return torchaudio.functional.amplitude_to_DB(x,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)#torch.sqrt(x)

    def grad_map(self, x):
        """Return (dB spectrogram, |gradient| of the LSGAN real-target loss
        w.r.t. the spectrogram) for visualization/diagnostics."""
        self.zero_grad()

        n_fft, hop_length, win_length = self.resolution

        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)

        specgram.requires_grad = True
        specgram.retain_grad()

        if specgram.grad is not None:
            specgram.grad.zero_()

        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]

        scores = self.run_layer_stack(y)[-1]

        # LSGAN loss against the "real" target 1
        loss = torch.mean((1 - scores) ** 2)
        loss.backward()

        return specgram.data[0], torch.abs(specgram.grad)[0]

    def relevance_map(self, x):
        """Project output scores back onto the input spectrogram.

        Uses a transposed convolution with a smooth kernel built from the
        measured receptive-field geometry; returns (dB spectrogram, relevance).
        """

        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = getattr(torch, 'hann_window')(win_length).to(x.device)

        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\
            window=window, return_complex=True) #[B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)


        scores = self.forward(x)[-1]

        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
        with torch.no_grad():
            # replicate-pad the score map so edge scores spread fully
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2))
            rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw]

        relevance = torch.zeros_like(specgram)
        relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv


        return specgram, relevance


    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """

        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations

        def newconv2d(layer,g):
            # clone *layer* with weights/bias transformed by g (e.g. clamped)

            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)

            try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError: pass

            try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError: pass

            return new_layer

        # empirical dB input bounds per FFT size for the z^B first-layer rule
        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }

        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]

        # weight norm must be removed so cloned convs see plain weights
        remove_all_weight_norms(self)

        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

        num_layers = len(self.layers)
        X = self.spectrogram(x). detach()


        # forward pass, keeping every intermediate activation
        A = [X.unsqueeze(1)] + [None] * len(self.layers)

        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])

        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)


        # keep only confidently scored positions as relevance seeds
        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask

        # backward pass: epsilon rule (LRP-eps) through the hidden layers
        R = [None] * num_layers + [r]

        for l in range(1, num_layers)[::-1]:
            A[l] = (A[l]).data.requires_grad_(True)

            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l+1] / z).data
            (z*s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data

        # first layer: z^B rule with input bounds [low, high]
        A[0] = (A[0].data).requires_grad_(True)

        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)

        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]

            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))

        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))


        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad

        R[0] = (A[0] * c + Xl * cp + Xh * cm)
        #R[0] = (A[0] * c).data

        return X, R[0].mean(dim=1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def create_3x3_conv_plan(num_layers : int,
                         f_stretch : int,
                         f_down : int,
                         t_stretch : int,
                         t_down : int
                         ):


    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
    """

    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):
        # plan strides/dilations/paddings for a single axis; each returned
        # list has length n_layers and padding always equals dilation so a
        # 3x3 kernel keeps "same" spatial alignment

        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch) , stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
        final_stride = 2 ** (max(down - stride_layers, 0))

        # extra dilation step reserved for the final layer if the requested
        # stretch could not be reached inside the stack
        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        # bugfix: this previously read `current_dilation * final_dilation`,
        # an expression statement with no effect, so the planned final
        # dilation step was silently dropped; apply it in place as intended.
        current_dilation *= final_dilation
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    # frequency entry first in every tuple
    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan
|
||||
|
||||
|
||||
class DiscriminatorExperimental(SpecDiscriminatorBase):
    """Spectrogram discriminator with frequency-positional embeddings,
    frequency-halving strides, ReLU activations and a sigmoid output score.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding= (1, 1)
        # +2 accounts for the sin/cos channels added by the embedding
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        # final score layer with sigmoid output
        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    # bugfix: `weight = weight + bias_val` only rebound the
                    # loop variable and left the parameters unchanged; add the
                    # offset in place so the biases are actually shifted.
                    weight += bias_val
|
||||
|
||||
|
||||
# Stretch/downsampling presets for DiscriminatorMagFree, keyed by design name.
# For each design, 'stretch' and 'down' map an STFT size (n_fft) to a
# (frequency, time) pair of log2 factors that are passed to
# create_3x3_conv_plan to lay out strides, dilations and paddings.
configs = {
    'f_down': {
        'stretch' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'ft_down': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'dilated': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (0, 0),
            256: (0, 0),
            512: (0, 0),
            1024: (0, 0),
            2048: (0, 0)
        }
    },
    'mixed': {
        'stretch' : {
            64 : (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down' : {
            64 : (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
}
|
||||
|
||||
|
||||
class DiscriminatorMagFree(SpecDiscriminatorBase):
    """Spectrogram discriminator whose stride/dilation layout is generated
    by create_3x3_conv_plan from a named design in `configs`.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('error: arch required in DiscriminatorMagFree')

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        # look up the (frequency, time) log2 stretch/down factors for this FFT size
        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        # +2 accounts for the sin/cos channels added by the embedding
        in_channels = 1 + 2
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # product over strides
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        # final score layer with sigmoid output
        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    # bugfix: `weight = weight + bias_val` only rebound the
                    # loop variable and left the parameters unchanged; add the
                    # offset in place so the biases are actually shifted.
                    weight += bias_val
|
||||
|
||||
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
    """Spectrogram discriminator with frequency-positional embeddings,
    frequency-halving strides, LeakyReLU activations and a raw (linear)
    output score."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        conv_stride = (2, 1)
        conv_padding = (1, 1)

        # +2 accounts for the sin/cos channels added by the embedding
        ch_in = 1 + 2
        ch_out = self.num_channels

        stack = []
        for _ in range(self.num_layers):
            stack.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(ch_in, ch_out, (3, 3), stride=conv_stride, padding=conv_padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            ch_in = ch_out + 2
            ch_out = min(2 * ch_out, self.num_channels_max)

        # final score layer without activation
        stack.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(ch_in, 1, (3, 3), padding=conv_padding))
            )
        )

        super().__init__(layers=stack, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
|
||||
class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
    """Spectrogram discriminator with 2D (frequency and time) positional
    embeddings and stride-2 downsampling along both axes.
    """

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.resolution = resolution
        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        # each PositionalEmbedding2D adds 4*d channels
        embedding_dim = 4 * d


        layers = []
        stride = (2, 2)
        padding= (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)


        layers.append(
            nn.Sequential(
                # bugfix: the final embedding previously used the default
                # PositionalEmbedding2D() (d=5); for any other d the conv's
                # expected input channels (out_channels + 4*d) would not match
                # the embedding's output.  Pass d explicitly.
                PositionalEmbedding2D(d),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
|
||||
class DiscriminatorMag(SpecDiscriminatorBase):
    """Plain magnitude-spectrogram discriminator: a stack of unit-stride
    3x3 convolutions with LeakyReLU, no positional embeddings, and a bare
    conv producing the final score."""

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        ch_in = 1
        ch_out = num_channels

        stack = []
        for _ in range(num_layers):
            stack.append(
                nn.Sequential(
                    norm_f(nn.Conv2d(ch_in, ch_out, (3, 3), stride=(1, 1), padding=(1, 1))),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            # channel count stays constant after the first layer
            ch_in = ch_out

        # final score layer: a bare conv, not wrapped in nn.Sequential
        stack.append(norm_f(nn.Conv2d(ch_in, 1, (3, 3), padding=(1, 1))))

        super().__init__(layers=stack, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
|
||||
|
||||
|
||||
# Registry mapping the architecture name accepted by
# TFDMultiResolutionDiscriminator to the discriminator class.
discriminators = {
    'mag': DiscriminatorMag,
    'freqpos': DiscriminatorMagFreqPosition,
    '2dpos': DiscriminatorMag2dPositional,
    'experimental': DiscriminatorExperimental,
    'free': DiscriminatorMagFree
}
|
||||
|
||||
class TFDMultiResolutionDiscriminator(torch.nn.Module):
    """Ensemble of spectrogram discriminators operating at multiple STFT
    resolutions; forward returns one output per sub-discriminator."""

    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        disc_cls = discriminators[architecture]

        sub_discs = []
        for size_16k in fft_sizes_16k:
            # rescale the nominal 16 kHz FFT size to the actual sampling rate
            n_fft = int(round(size_16k * fs / 16000))
            # resolution triple: (n_fft, hop_length, win_length)
            res = [n_fft, n_fft // 4, n_fft]
            sub_discs.append(
                disc_cls(res, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain,
                         use_spectral_norm=use_spectral_norm, **kwargs)
            )

        self.discriminators = nn.ModuleList(sub_discs)

    def forward(self, y):
        """Run every sub-discriminator on *y*; returns a list of their outputs."""
        return [disc(y) for disc in self.discriminators]
|
||||
|
||||
|
||||
class FWGAN_disc_wrapper(nn.Module):
    """Adapter exposing a multi-output discriminator through the FWGAN
    interface: forward(y, y_hat) returns (real scores, fake scores,
    real feature maps, fake feature maps)."""

    def __init__(self, disc):
        super().__init__()

        self.disc = disc

    def forward(self, y, y_hat):
        real_outputs = self.disc(y)
        fake_outputs = self.disc(y_hat)

        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []

        for real_maps, fake_maps in zip(real_outputs, fake_outputs):
            # last entry of each output is the score; the rest are feature maps
            y_d_rs.append(real_maps[-1])
            fmap_rs.append(real_maps[:-1])
            y_d_gs.append(fake_maps[-1])
            fmap_gs.append(fake_maps[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
190
managed_components/78__esp-opus/dnn/torch/osce/models/lace.py
Normal file
190
managed_components/78__esp-opus/dnn/torch/osce/models/lace.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
|
||||
from dnntools.sparsification import create_sparsifier
|
||||
|
||||
|
||||
class LACE(NNSBase):
    """ Linear-Adaptive Coding Enhancer

    Post-filter that enhances decoded SILK speech (see repo README): the
    signal is passed through two pitch-adaptive comb filters followed by an
    adaptive spectral-shaping convolution, all conditioned on SILK features,
    pitch lags and a bitrate (numbits) embedding.
    """
    # samples per processing frame (5 ms at the 16 kHz rate assumed by flop_count)
    FRAME_SIZE=80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[10000, 30000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        """Build embeddings, the conditioning feature net and the adaptive filters.

        num_features          -- dimension of the per-frame SILK feature vector
        pitch_embedding_dim   -- size of the learned pitch-lag embedding
        cond_dim              -- width of the conditioning vectors from the feature net
        pitch_max             -- largest pitch lag; embedding table has pitch_max + 1 rows
        kernel_size           -- kernel size of the comb / conv filter layers
        preemph, skip         -- forwarded to NNSBase (pre-emphasis and output alignment)
        numbits_range         -- value range mapped by the log-scale numbits embedding
        partial_lookahead     -- selects SilkFeatureNetPL over SilkFeatureNet
        sparsify              -- if True, attaches a dnntools sparsifier to the model

        NOTE(review): several defaults are mutable lists; they are only read here,
        but tuples would be the safer idiom.
        """

        super().__init__(skip=skip, preemph=preemph)


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (log-scale; flattened to 2 * numbits_embedding_dim in forward,
        # implying numbits carries two values per frame -- TODO confirm against caller)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters; padding is split so the kernel is (nearly) centered
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping (causal: all padding on the left)
        self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        if sparsify:
            # schedule is unpacked positionally into create_sparsifier (see dnntools)
            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Return an estimate of FLOPs per second at sample rate `rate`."""

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhance signal x.

        x        -- input signal fed to the adaptive comb filters
                    (assumed (batch, 1, num_frames * FRAME_SIZE) -- TODO confirm)
        features -- SILK features, (batch, num_frames, num_features)
        periods  -- integer pitch lags, (batch, num_frames, 1); squeezed below
        numbits  -- bitrate information; embedding is flattened along the last dims
        """

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        # concatenate all conditioning inputs and derive the conditioning vectors
        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # two pitch-adaptive comb stages followed by adaptive spectral shaping
        y = self.cf1(x, cf, periods, debug=debug)

        y = self.cf2(y, cf, periods, debug=debug)

        y = self.af1(y, cf, debug=debug)

        return y

    def get_impulse_responses(self, features, periods, numbits):
        """ generates impulse responses on frame centers (input without batch dimension) """

        num_frames = features.size(0)
        batch_size = 32
        # response length: long enough to cover two pitch periods plus filter taps
        max_len = 2 * (self.pitch_max + self.kernel_size) + 10

        # spread out some pulses: batch element b carries unit pulses at frame
        # centers of frames b, b + batch_size, b + 2*batch_size, ... so responses
        # never overlap within one batch element
        x = np.zeros((batch_size, 1, num_frames * self.FRAME_SIZE))
        for b in range(batch_size):
            x[b, :, self.FRAME_SIZE // 2 + b * self.FRAME_SIZE :: batch_size * self.FRAME_SIZE] = 1

        # prepare input: replicate conditioning across the batch
        x = torch.from_numpy(x).float().to(features.device)
        features = torch.repeat_interleave(features.unsqueeze(0), batch_size, 0)
        periods = torch.repeat_interleave(periods.unsqueeze(0), batch_size, 0)
        numbits = torch.repeat_interleave(numbits.unsqueeze(0), batch_size, 0)

        # run network (same pipeline as forward, without autograd)
        with torch.no_grad():
            periods = periods.squeeze(-1)
            pitch_embedding = self.pitch_embedding(periods)
            numbits_embedding = self.numbits_embedding(numbits).flatten(2)
            full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
            cf = self.feature_net(full_features)
            y = self.cf1(x, cf, periods, debug=False)
            y = self.cf2(y, cf, periods, debug=False)
            y = self.af1(y, cf, debug=False)

        # collect responses: response i starts at the pulse in batch element i % batch_size
        y = y.detach().squeeze().cpu().numpy()
        cut_frames = (max_len + self.FRAME_SIZE - 1) // self.FRAME_SIZE
        num_responses = num_frames - cut_frames  # drop trailing frames whose window would run past the end
        responses = np.zeros((num_responses, max_len))

        for i in range(num_responses):
            b = i % batch_size
            start = self.FRAME_SIZE // 2 + i * self.FRAME_SIZE
            stop = start + max_len

            responses[i, :] = y[b, start:stop]

        return responses
|
||||
274
managed_components/78__esp-opus/dnn/torch/osce/models/lavoce.py
Normal file
274
managed_components/78__esp-opus/dnn/torch/osce/models/lavoce.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.noise_shaper import NoiseShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.lpcnet_feature_net import LPCNetFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
def print_channels(y, prefix="", name="", rate=16000):
    """Dump every channel of batch element 0 of `y` via endoscopy's write_data.

    y      -- tensor with channels on dim 1 (only y[0] is written)
    prefix -- channel-name prefix; channel index is appended as _cNN
    name   -- optional suffix appended to the channel name
    rate   -- sample rate passed through to write_data
    """
    num_channels = y.size(1)
    for i in range(num_channels):
        channel_name = f"{prefix}_c{i:02d}"
        if len(name) > 0: channel_name += "_" + name
        ch = y[0, i, :].detach().cpu().numpy()
        # Normalize by peak magnitude before the int16 cast. The previous code
        # divided by np.max(ch), which divides by zero for silent channels,
        # flips the sign for all-negative channels, and can overflow int16 when
        # the negative peak exceeds twice the positive one.
        peak = np.max(np.abs(ch))
        scale = (2**14) / peak if peak > 0 else 0.0
        ch = (scale * ch).astype(np.int16)
        write_data(channel_name, ch, rate)
|
||||
|
||||
|
||||
|
||||
class LaVoce(nn.Module):
    """ Linear-Adaptive VOCodEr

    Generates a speech waveform from LPCNet-style acoustic features and pitch
    lags: a harmonic phase signal (or pulse pair) plus shaped noise is mixed
    and refined by a cascade of time-domain shapers, adaptive comb filters and
    adaptive convolutions, with the conditioning updated between stages.
    """
    FEATURE_FRAME_SIZE=160  # samples per feature frame (10 ms at 16 kHz -- TODO confirm rate)
    FRAME_SIZE=80           # samples per synthesis frame; features are upsampled to this rate

    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False,
                 innovate1=True,
                 innovate2=False,
                 innovate3=False,
                 ftrans_k=2):
        """Build embeddings, the feature net and all synthesis stages.

        pulses            -- excite with rectified pulse pairs instead of sin/cos
        innovate1..3      -- enable the innovation path of the three TDShaper stages
        ftrans_k          -- kernel size of the inter-stage feature-transform convs

        NOTE(review): the gain-limit defaults are mutable lists (read-only here).
        """

        super().__init__()


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses
        self.ftrans_k = ftrans_k

        # features must upsample to synthesis frames by an integer factor
        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters ((nearly) centered kernels)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)


        # pre-net mixers: prescale the 2-channel phase signal down to 1 channel,
        # then mix it with the shaped noise (causal padding throughout)
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate1)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate2)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=innovate3)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms: one causal conv applied to the conditioning after
        # the correspondingly named synthesis stage (see forward)
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, ftrans_k)


    def create_phase_signals(self, periods):
        """Create a 2-channel per-sample excitation from per-frame pitch lags.

        Phase is accumulated across frames (phase0), so the oscillator is
        continuous at frame boundaries. Channels are (sin, cos), or a pair of
        rectified pulse trains when self.pulses is set.
        """

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            # per-frame angular frequency from the pitch lag
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                # alpha thresholds the sinusoid so only the peaks survive,
                # yielding narrow positive/negative pulse trains
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)

            # carry accumulated phase into the next frame
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
        """Return an estimate of FLOPs per second at sample rate `rate`."""

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        """Apply a causal conv `layer` to conditioning f (batch, frames, cond_dim)."""
        f = f.permute(0, 2, 1)
        f = F.pad(f, [self.ftrans_k - 1, 0])  # left-pad only: causal
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
        """Synthesize a waveform from features and per-frame pitch lags.

        features -- (batch, num_feature_frames, num_features)
        periods  -- integer pitch lags, (batch, num_feature_frames, 1)
        debug    -- if True, dump intermediate signals via print_channels
        """

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods to the synthesis frame rate
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net: harmonic excitation + shaped noise, mixed to 2 channels
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        if debug: print_channels(ref_phase, prefix="lavoce_01", name="pulse")
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        if debug: print_channels(torch.cat((x, noise), dim=1), prefix="lavoce_02", name="inputs")
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)
        if debug: print_channels(y, prefix="lavoce_03", name="postselect1")

        # temporal shaping + innovating: shape channel 1 only, keep channel 0,
        # recombine, then refresh the conditioning for the next stage
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_04", name="postshape1")
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_05", name="postselect2")
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_06", name="postshape2")
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_07", name="postmix1")
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_08", name="postcomb1")
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        if debug: print_channels(y, prefix="lavoce_09", name="postcomb2")
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_10", name="postselect3")
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        if debug: print_channels(y2, prefix="lavoce_11", name="postshape3")
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)
        if debug: print_channels(y, prefix="lavoce_12", name="postmix2")

        return y

    def process(self, features, periods, debug=False):
        """Inference helper: run forward on unbatched inputs and return int16 samples."""

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis (sample-by-sample IIR in Python: simple but slow)
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.noise_shaper import NoiseShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.lpcnet_feature_net import LPCNetFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
class LaVoce400(nn.Module):
    """ Linear-Adaptive VOCodEr

    Variant of LaVoce with a 40-sample synthesis frame (upsampling factor 4
    from the 160-sample feature frame) and fixed innovation settings for the
    TDShaper stages.
    """
    FEATURE_FRAME_SIZE=160  # samples per feature frame
    FRAME_SIZE=40           # samples per synthesis frame

    def __init__(self,
                 num_features=20,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=300,
                 kernel_size=15,
                 preemph=0.85,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 norm_p=2,
                 avg_pool_k=4,
                 pulses=False):
        """Build embeddings, the feature net and all synthesis stages.

        pulses -- excite with rectified pulse pairs instead of sin/cos

        NOTE(review): the gain-limit defaults are mutable lists (read-only here).
        """

        super().__init__()


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.pulses = pulses

        # features must upsample to synthesis frames by an integer factor
        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # feature net
        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)

        # noise shaper
        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)

        # comb filters ((nearly) centered kernels; overlap halved vs LaVoce to
        # match the shorter frame)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)


        # pre-net mixers (causal padding throughout)
        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # spectral shaping
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # non-linear transforms (only the first stage innovates)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)

        # combinators
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

        # feature transforms (fixed kernel size 2, unlike LaVoce's ftrans_k)
        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)


    def create_phase_signals(self, periods):
        """Create a 2-channel per-sample excitation from per-frame pitch lags.

        Phase is accumulated across frames (phase0), so the oscillator is
        continuous at frame boundaries. Channels are (sin, cos), or a pair of
        rectified pulse trains when self.pulses is set.
        """

        batch_size = periods.size(0)
        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            # per-frame angular frequency from the pitch lag
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            if self.pulses:
                # alpha thresholds the sinusoid so only the peaks survive
                alpha = torch.cos(f).view(batch_size, 1, 1)
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)

                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
            else:
                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)

                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)

            # carry accumulated phase into the next frame
            phase0 = phase0 + self.FRAME_SIZE * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=-1)

        return phase_signals

    def flop_count(self, rate=16000, verbose=False):
        """Return an estimate of FLOPs per second at sample rate `rate`."""

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops

    def feature_transform(self, f, layer):
        """Apply a causal kernel-2 conv `layer` to conditioning f (batch, frames, cond_dim)."""
        f = f.permute(0, 2, 1)
        f = F.pad(f, [1, 0])  # left-pad only: causal
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, features, periods, debug=False):
        """Synthesize a waveform from features and per-frame pitch lags.

        features -- (batch, num_feature_frames, num_features)
        periods  -- integer pitch lags, (batch, num_feature_frames, 1)
        debug    -- if True, dump the pre-net channels via endoscopy
        """

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)

        full_features = torch.cat((features, pitch_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # upsample periods to the synthesis frame rate
        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)

        # pre-net: harmonic excitation + shaped noise, mixed to 2 channels
        ref_phase = torch.tanh(self.create_phase_signals(periods))
        x = self.af_prescale(ref_phase, cf)
        noise = self.noise_shaper(cf)
        y = self.af_mix(torch.cat((x, noise), dim=1), cf)

        if debug:
            # NOTE(review): scaling by 2**15 / np.max(ch) can wrap in int16 (peak
            # maps to 32768 > 32767) and divides by zero for silent channels;
            # debug-only path, left as-is.
            ch0 = y[0,0,:].detach().cpu().numpy()
            ch1 = y[0,1,:].detach().cpu().numpy()
            ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16)
            ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16)
            write_data('prior_channel0', ch0, 16000)
            write_data('prior_channel1', ch1, 16000)

        # temporal shaping + innovating: shape channel 1 only, keep channel 0,
        # recombine, then refresh the conditioning for the next stage
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        # spectral shaping
        y = self.cf1(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # final temporal env adjustment
        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)

        return y

    def process(self, features, periods, debug=False):
        """Inference helper: run forward on unbatched inputs and return int16 samples."""

        self.eval()
        device = next(iter(self.parameters())).device
        with torch.no_grad():

            # run model
            f = features.unsqueeze(0).to(device)
            p = periods.unsqueeze(0).to(device)

            y = self.forward(f, p, debug=debug).squeeze()

            # deemphasis (sample-by-sample IIR in Python: simple but slow)
            if self.preemph > 0:
                for i in range(len(y) - 1):
                    y[i + 1] += self.preemph * y[i]

            # clip to valid range
            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()

        return out
|
||||
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class LPCNetFeatureNet(nn.Module):
    """Conditioning network for the LaVoce vocoders.

    Two kernel-3 convolutions over the frame axis, a learned transposed-conv
    upsampler and a GRU map per-frame acoustic features to one conditioning
    vector per (upsampled) synthesis frame.
    """

    def __init__(self,
                 feature_dim=84,
                 num_channels=256,
                 upsamp_factor=2,
                 lookahead=True):
        """feature_dim    -- input feature dimension per frame
        num_channels   -- width of every internal layer and of the output
        upsamp_factor  -- frame-rate multiplier realized by the transposed conv
        lookahead      -- if True, the first conv sees one future frame
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.upsamp_factor = upsamp_factor
        self.lookahead = lookahead

        # frame-level convolutions, kernel size 3 each
        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3)

        # recurrent smoothing of the upsampled sequence
        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

        # learned upsampler: stride == kernel == upsamp_factor
        self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor)

    def flop_count(self, rate=100):
        """Estimate FLOPs per second at frame rate `rate`."""
        conv_total = sum(_conv1d_flop_count(layer, rate)
                         for layer in (self.conv1, self.conv2, self.tconv))

        # GRU: three gates, each with an input and a recurrent matrix product
        gru_total = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate

        return conv_total + gru_total


    def forward(self, features, state=None):
        """ features shape: (batch_size, num_frames, feature_dim) """

        if state is None:
            # default initial GRU state: zeros on the input's device
            state = torch.zeros((1, features.size(0), self.num_channels), device=features.device)

        # convs operate channel-first
        h = features.permute(0, 2, 1)

        # first conv is centered (one frame of lookahead) unless disabled;
        # the second conv is always causal (left padding only)
        first_pad = [1, 1] if self.lookahead else [2, 0]
        h = torch.tanh(self.conv1(F.pad(h, first_pad)))
        h = torch.tanh(self.conv2(F.pad(h, [2, 0])))

        # upsample to the synthesis frame rate
        h = torch.tanh(self.tconv(h))

        # back to (batch, frames, channels) for the batch-first GRU
        h = h.permute(0, 2, 1)
        out, _ = self.gru(h, state)

        return out
|
||||
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class NNSBase(nn.Module):
    """Base class for neural SILK enhancement models.

    Holds the deemphasis coefficient and the model delay and provides
    :meth:`process`, which runs a whole utterance through the subclass's
    ``forward`` and converts the result to 16-bit PCM.
    """

    def __init__(self, skip=91, preemph=0.85):
        super().__init__()

        self.skip = skip          # model delay in samples, removed in process()
        self.preemph = preemph    # preemphasis coefficient undone in process()

    def process(self, sig, features, periods, numbits, debug=False):
        """Enhance one utterance and return it as int16 samples.

        sig: 1d input signal; features / periods / numbits: per-frame side
        information (numbits is only forwarded when the subclass's forward
        declares a ``numbits`` parameter).
        """

        self.eval()
        # subclasses differ in whether forward() takes a numbits argument;
        # inspect its code object to pick the matching call signature
        takes_numbits = 'numbits' in self.forward.__code__.co_varnames
        device = next(iter(self.parameters())).device

        with torch.no_grad():

            # move inputs to the model device, adding batch (and channel) dims
            signal = sig.view(1, 1, -1).to(device)
            feats = features.unsqueeze(0).to(device)
            pitches = periods.unsqueeze(0).to(device)
            bits = numbits.unsqueeze(0).to(device)

            model_args = (signal, feats, pitches, bits) if takes_numbits else (signal, feats, pitches)
            y = self.forward(*model_args, debug=debug).squeeze()

            # deemphasis: invert the one-tap preemphasis filter (inherently
            # sequential IIR, hence the explicit loop)
            if self.preemph > 0:
                for k in range(1, len(y)):
                    y[k] = y[k] + self.preemph * y[k - 1]

            # delay compensation: drop the first `skip` samples, zero-pad the end
            tail = torch.zeros(self.skip, dtype=y.dtype, device=y.device)
            y = torch.cat((y[self.skip:], tail))

            # scale to int16 range and clip to valid values
            out = torch.clip(y * 32768, -32768, 32767).short()

        return out
|
||||
218
managed_components/78__esp-opus/dnn/torch/osce/models/no_lace.py
Normal file
218
managed_components/78__esp-opus/dnn/torch/osce/models/no_lace.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
from dnntools.quantization import soft_quant
|
||||
from dnntools.sparsification import create_sparsifier, mark_for_sparsification
|
||||
|
||||
class NoLACE(NNSBase):
    """ Non-Linear Adaptive Coding Enhancer

    Enhancement model for SILK-coded speech: two adaptive comb filters keyed
    on the pitch lag, followed by alternating adaptive convolutions and
    non-linear TDShaper stages, all conditioned on a feature stream derived
    from SILK features, a pitch embedding and a bitrate embedding.
    """
    FRAME_SIZE=80  # samples per SILK frame at 16 kHz

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=91,
                 comb_gain_limit_db=-6,
                 global_gain_limits_db=[-6, 6],
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 avg_pool_k=4,
                 pool_after=False,
                 softquant=False,
                 sparsify=False,
                 sparsification_schedule=[100, 1000, 100],
                 sparsification_density=0.5,
                 apply_weight_norm=False):
        # NOTE(review): the list defaults above are shared across instances;
        # they are only read here, but mutating them later would leak state.

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        # preemph and skip are also stored by NNSBase.__init__; kept here
        # redundantly (harmless, same values)
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead

        # scalar density applies to all 10 sparsified weight groups
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 10 * [sparsification_density]

        # identity wrapper unless weight normalization is requested
        norm = weight_norm if apply_weight_norm else lambda x, name=None: x

        # pitch embedding (pitch_max + 1 entries: lags 0..pitch_max)
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (log-scaled sinusoidal embedding of the bitrate)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net; 2 * numbits_embedding_dim because numbits has two
        # channels that are flattened together in forward()
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # comb filters (centered padding so the comb taps straddle the lag)
        left_pad = self.kernel_size // 2
        right_pad = self.kernel_size - 1 - left_pad
        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # spectral shaping (causal padding: kernel_size - 1 on the left)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # non-linear transforms applied to the second channel between mixes
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # combinators mixing the linear and shaped channels; af4 folds the
        # two channels back into one output signal
        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)

        # feature transforms: evolve the conditioning stream after each stage
        # (kernel size 2, causally padded in feature_transform)
        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))

        if softquant:
            self.post_cf1 = soft_quant(self.post_cf1)
            self.post_cf2 = soft_quant(self.post_cf2)
            self.post_af1 = soft_quant(self.post_af1)
            self.post_af2 = soft_quant(self.post_af2)
            self.post_af3 = soft_quant(self.post_af3)

        if sparsify:
            # densities 0..3 are consumed by the feature net above;
            # 4..8 cover the five feature-transform convolutions
            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))

            self.sparsifier = create_sparsifier(self, *sparsification_schedule)

    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPs per second at the given sample rate.

        Frame-rate modules (feature net, feature transforms) are counted at
        rate / FRAME_SIZE; sample-rate modules at `rate`.
        """

        frame_rate = rate / self.FRAME_SIZE

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
        shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))

        if verbose:
            # NOTE(review): shape_flops is included in the total below but
            # not printed here
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")

        return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops

    def feature_transform(self, f, layer):
        """Apply one causal kernel-2 Conv1d + tanh to the conditioning
        stream f of shape (batch, frames, cond_dim)."""
        f0 = f.permute(0, 2, 1)
        f = F.pad(f0, [1, 0])  # causal left pad for the kernel-2 conv
        f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)

    def forward(self, x, features, periods, numbits, debug=False):
        """Enhance signal x (batch, 1, samples) conditioned on per-frame
        features, pitch periods and bitrate (numbits); returns a tensor of
        the same shape as x."""

        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        # conditioning stream shared (and evolved) across all stages
        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # stage 1/2: pitch-aware comb filtering
        y = self.cf1(x, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf1)

        y = self.cf2(y, cf, periods, debug=debug)
        cf = self.feature_transform(cf, self.post_cf2)

        # split into two channels for the shape/mix stages
        y = self.af1(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af1)

        # stages 3-5: keep channel 0 linear, shape channel 1, re-mix
        y1 = y[:, 0:1, :]
        y2 = self.tdshape1(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af2(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af2)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape2(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af3(y, cf, debug=debug)
        cf = self.feature_transform(cf, self.post_af3)

        y1 = y[:, 0:1, :]
        y2 = self.tdshape3(y[:, 1:2, :], cf)
        y = torch.cat((y1, y2), dim=1)
        y = self.af4(y, cf, debug=debug)  # fold back to a single channel

        return y
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ScaleEmbedding(nn.Module):
    """Sinusoidal embedding of a bounded scalar.

    Maps a scalar in [min_val, max_val] (optionally on a log scale) to a
    `dim`-dimensional vector of sines at learnable frequencies, initialized
    as integer multiples of pi over the value range.
    """

    def __init__(self,
                 dim,
                 min_val,
                 max_val,
                 logscale=False):

        super().__init__()

        if min_val >= max_val:
            raise ValueError('min_val must be smaller than max_val')
        if logscale and min_val <= 0:
            raise ValueError('min_val must be positive when logscale is true')

        self.dim = dim
        self.logscale = logscale
        # store the range in the (possibly log-transformed) domain used
        # internally by forward()
        self.min_val = m.log(min_val) if logscale else min_val
        self.max_val = m.log(max_val) if logscale else max_val

        # center of the value range; inputs are shifted to be symmetric
        self.offset = (self.min_val + self.max_val) / 2
        # frequencies k * pi / range for k = 1..dim (learnable)
        self.scale_factors = nn.Parameter(
            torch.arange(1, dim + 1, dtype=torch.float32) * torch.pi / (self.max_val - self.min_val)
        )

    def forward(self, x):
        """Embed x (any shape) into shape (*x.shape, dim); out-of-range
        values are clipped to the embedding range."""
        if self.logscale:
            x = torch.log(x)
        centered = torch.clip(x, self.min_val, self.max_val) - self.offset
        return torch.sin(centered.unsqueeze(-1) * self.scale_factors - 0.5)
|
||||
@@ -0,0 +1,179 @@
|
||||
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.silk_upsampler import SilkUpsampler
|
||||
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
|
||||
from utils.layers.td_shaper import TDShaper
|
||||
from utils.layers.deemph import Deemph
|
||||
from utils.misc import freeze_model
|
||||
|
||||
from models.nns_base import NNSBase
|
||||
from models.silk_feature_net_pl import SilkFeatureNetPL
|
||||
from models.silk_feature_net import SilkFeatureNet
|
||||
from .scale_embedding import ScaleEmbedding
|
||||
|
||||
|
||||
|
||||
class ShapeUp48(NNSBase):
    """Bandwidth-extension model: 16 kHz SILK output to target_fs (48 kHz).

    Pipeline: optional frozen prenet enhancement at 16 kHz, high-quality 2x
    upsampling to 32 kHz, adaptive filtering plus temporal shaping with a
    shaped-noise channel, 3/2 interpolation to 48 kHz and a final adaptive
    mixing stage.
    """
    FRAME_SIZE16k=80  # samples per SILK frame at 16 kHz

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)


        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        # preemph and skip are also stored by NNSBase.__init__ (same values)
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead
        # + .1 guards against float truncation in the int() conversion
        self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32 = self.FRAME_SIZE16k * 2
        self.noise_amplitude = noise_amplitude
        self.prenet = prenet

        # freeze prenet if given
        if prenet is not None:
            freeze_model(self.prenet)
            try:
                self.deemph = Deemph(prenet.preemph)
            # FIX: was a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit and masked unrelated Deemph
            # construction errors; only a missing attribute is expected here
            except AttributeError:
                print("[warning] prenet model is expected to have preemph attribute")
                self.deemph = Deemph(0)



        # upsampler (2x high quality, then 3/2 interpolation)
        self.upsampler = SilkUpsampler()

        # pitch embedding (pitch_max + 1 entries: lags 0..pitch_max)
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding (log-scaled sinusoidal embedding of the bitrate)
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net; 2 * numbits_embedding_dim because numbits has two
        # channels flattened together in forward()
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms (tdshape1 at 32 kHz, tdshape2 at 48 kHz)
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping (causal padding; af_noise shapes the noise channel)
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)


    def flop_count(self, rate=16000, verbose=False):
        """Estimate FLOPs per second; `rate` is the 16 kHz input rate.

        af1 runs at 2x rate (32 kHz) and af3 at 3x rate (48 kHz), hence the
        scaled rates below.
        NOTE(review): af_noise and the two TDShaper stages are not counted.
        """

        frame_rate = rate / self.FRAME_SIZE16k

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
        """Upsample and enhance x (batch, 1, samples @16 kHz); returns
        (batch, 1, samples @target_fs)."""

        # optional frozen 16 kHz enhancement; deemphasis undoes the
        # prenet's output preemphasis before further processing
        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)



        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        # conditioning stream shared by all adaptive stages
        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # 16 kHz -> 32 kHz
        y32 = self.upsampler.hq_2x_up(x)

        # spectrally shaped noise channel for the high band
        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)

        y32 = self.af1(y32, cf, debug=debug)

        # keep channel 0 linear, shape channel 1, mix with the noise channel
        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)

        y32 = self.af2(y32, cf, debug=debug)

        # 32 kHz -> 48 kHz
        y48 = self.upsampler.interpolate_3_2(y32)

        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
        y48 = torch.cat((y48_1, y48_2), dim=1)

        y48 = self.af3(y48, cf, debug=debug)  # fold back to a single channel

        return y48
|
||||
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class SilkFeatureNet(nn.Module):
    """Conditioning network for SILK features.

    Two 1d convolutions (the second with dilation 2), either fully causal or
    with symmetric lookahead padding, followed by a GRU over frames.
    """

    def __init__(self,
                 feature_dim=47,
                 num_channels=256,
                 lookahead=False):
        """feature_dim: per-frame input feature size; num_channels: width of
        all internal layers; lookahead: allow future context in both convs."""

        super().__init__()

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.lookahead = lookahead

        # module creation order is kept stable so parameter initialization
        # is reproducible under a fixed RNG seed
        self.conv1 = nn.Conv1d(feature_dim, num_channels, 3)
        self.conv2 = nn.Conv1d(num_channels, num_channels, 3, dilation=2)

        self.gru = nn.GRU(num_channels, num_channels, batch_first=True)

    def flop_count(self, rate=200):
        """Estimate FLOPs per second at the given frame rate."""
        conv_flops = sum(_conv1d_flop_count(layer, rate)
                         for layer in (self.conv1, self.conv2))
        # GRU: three input and three recurrent matmuls, 2 FLOPs per MAC
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """ features shape: (batch_size, num_frames, feature_dim) """

        num_sequences = features.size(0)

        if state is None:
            state = torch.zeros((1, num_sequences, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)

        # lookahead: symmetric padding (conv2's dilation-2 receptive field
        # needs 2 on each side); otherwise left-pad only for causality
        pad1, pad2 = ([1, 1], [2, 2]) if self.lookahead else ([2, 0], [4, 0])
        h = torch.tanh(self.conv1(F.pad(h, pad1)))
        h = torch.tanh(self.conv2(F.pad(h, pad2)))

        h = h.permute(0, 2, 1)
        h, _ = self.gru(h, state)

        return h
|
||||
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
import sys
|
||||
sys.path.append('../dnntools')
|
||||
import numbers
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
from dnntools.quantization.softquant import soft_quant
|
||||
from dnntools.sparsification import mark_for_sparsification
|
||||
|
||||
class SilkFeatureNetPL(nn.Module):
    """ feature net with partial lookahead

    Reduces per-frame features, stacks groups of four frames, processes them
    at the reduced rate with a kernel-2 convolution, upsamples back 4x with a
    transposed convolution and smooths with a GRU. Supports optional soft
    quantization, sparsification marking and weight normalization.
    """
    def __init__(self,
                 feature_dim=47,
                 num_channels=256,
                 hidden_feature_dim=64,
                 softquant=False,
                 sparsify=True,
                 sparsification_density=0.5,
                 apply_weight_norm=False):

        super().__init__()

        # scalar density applies to all four sparsified weight groups
        if isinstance(sparsification_density, numbers.Number):
            sparsification_density = 4 * [sparsification_density]

        self.feature_dim = feature_dim
        self.num_channels = num_channels
        self.hidden_feature_dim = hidden_feature_dim

        # identity wrapper unless weight normalization is requested
        wrap = weight_norm if apply_weight_norm else lambda x, name=None: x

        # creation order is kept stable so parameter initialization is
        # reproducible under a fixed RNG seed
        self.conv1 = wrap(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
        self.conv2 = wrap(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
        self.tconv = wrap(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
        # both GRU weight matrices are (optionally) weight-normalized
        self.gru = wrap(wrap(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')

        if softquant:
            self.conv2 = soft_quant(self.conv2)
            self.tconv = soft_quant(self.tconv)
            self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])

        if sparsify:
            mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
            mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
            mark_for_sparsification(
                self.gru,
                {
                    'W_ir' : (sparsification_density[2], [8, 4], False),
                    'W_iz' : (sparsification_density[2], [8, 4], False),
                    'W_in' : (sparsification_density[2], [8, 4], False),
                    'W_hr' : (sparsification_density[3], [8, 4], True),
                    'W_hz' : (sparsification_density[3], [8, 4], True),
                    'W_hn' : (sparsification_density[3], [8, 4], True),
                }
            )

    def flop_count(self, rate=200):
        """Estimate FLOPs per second at the given frame rate."""
        conv_flops = sum(_conv1d_flop_count(layer, rate)
                         for layer in (self.conv1, self.conv2, self.tconv))
        # GRU: three input and three recurrent matmuls, 2 FLOPs per MAC
        gru_flops = 2 * (3 * self.gru.input_size * self.gru.hidden_size
                         + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
        return conv_flops + gru_flops

    def forward(self, features, state=None):
        """ features shape: (batch_size, num_frames, feature_dim) """

        num_sequences = features.size(0)
        num_frames = features.size(1)  # must be a multiple of 4

        if state is None:
            state = torch.zeros((1, num_sequences, self.num_channels), device=features.device)

        h = features.permute(0, 2, 1)

        # per-frame dimensionality reduction
        h = torch.tanh(self.conv1(h))

        # frame accumulation: stack groups of four consecutive frames into
        # one vector, then apply a causal kernel-2 convolution
        h = h.permute(0, 2, 1)
        h = h.reshape(num_sequences, num_frames // 4, -1).permute(0, 2, 1)
        h = torch.tanh(self.conv2(F.pad(h, [1, 0])))

        # upsampling back to the original frame rate
        h = torch.tanh(self.tconv(h))
        h = h.permute(0, 2, 1)

        h, _ = self.gru(h, state)

        return h
|
||||
@@ -0,0 +1,9 @@
|
||||
pyyaml==6.0.1
|
||||
torch==2.0.1
|
||||
numpy==1.25.2
|
||||
scipy==1.11.2
|
||||
pesq==0.0.4
|
||||
gitpython==3.1.41
|
||||
matplotlib==3.7.3
|
||||
torchaudio==2.0.2
|
||||
tqdm==4.66.1
|
||||
107716
managed_components/78__esp-opus/dnn/torch/osce/resources/training_files.txt
Normal file
107716
managed_components/78__esp-opus/dnn/torch/osce/resources/training_files.txt
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
import resampy
|
||||
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("filelist", type=str, help="file with filenames for concatenation in WAVE format")
|
||||
parser.add_argument("target_fs", type=int, help="target sampling rate of concatenated file")
|
||||
parser.add_argument("output", type=str, help="binary output file (integer16)")
|
||||
parser.add_argument("--basedir", type=str, help="basedir for filenames in filelist, defaults to ./", default="./")
|
||||
parser.add_argument("--normalize", action="store_true", help="apply normalization")
|
||||
parser.add_argument("--db_max", type=float, help="max DB for random normalization", default=0)
|
||||
parser.add_argument("--db_min", type=float, help="min DB for random normalization", default=0)
|
||||
parser.add_argument("--verbose", action="store_true")
|
||||
|
||||
def read_filelist(basedir, filelist):
    """Read a filelist and return the non-empty entries joined onto basedir."""
    with open(filelist, "r") as f:
        lines = f.readlines()

    names = [line.rstrip('\n') for line in lines]
    return [os.path.join(basedir, name) for name in names if len(name) > 0]
|
||||
|
||||
def read_wave(file, target_fs):
    """Read a wave file and return its samples as float32 at target_fs.

    Files sampled below target_fs are rejected (returns None) rather than
    up-sampled; files at a higher rate are resampled down via resampy.
    """
    fs, x = wavfile.read(file)

    if fs < target_fs:
        # bug fix: the original had an unreachable print after this return;
        # refuse to up-sample and let the caller skip the item
        return None

    if fs != target_fs:
        x = resampy.resample(x, fs, target_fs)

    return x.astype(np.float32)
|
||||
|
||||
def random_normalize(x, db_min, db_max, max_val=2**15 - 1):
    """Scale x so its peak sits at a random level in [db_min, db_max] dB re max_val."""
    level_db = np.random.uniform(db_min, db_max, 1)
    peak = np.abs(x).max()
    gain = 10 ** (level_db / 20) * max_val / peak
    return gain * x
|
||||
|
||||
|
||||
def concatenate(filelist : str, output : str, target_fs: int, normalize=True, db_min=0, db_max=0, verbose=False):
    """Concatenate wave files into one raw int16 stream with short cross-fades.

    Consecutive items are overlap-added over `overlap_size` samples (40 samples
    at 8 kHz, scaled to target_fs) using complementary raised-cosine windows.
    Items that cannot be read at target_fs or are too short are skipped.
    """

    overlap_size = int(40 * target_fs / 8000)
    overlap_mem = np.zeros(overlap_size, dtype=np.float32)
    # fade-out window; its flipped counterpart fades the next item in
    overlap_win1 = (0.5 + 0.5 * np.cos(np.arange(0, overlap_size) * np.pi / overlap_size)).astype(np.float32)
    overlap_win2 = np.flipud(overlap_win1)

    with open(output, 'wb') as f:
        for file in filelist:
            x = read_wave(file, target_fs)
            if x is None: continue

            # too short for a meaningful cross-fade
            if len(x) < 10 * overlap_size:
                if verbose: print(f"skipping {file}...")
                continue
            elif verbose:
                print(f"processing {file}...")

            if normalize:
                x = random_normalize(x, db_min, db_max)

            # overlap-add the carried tail of the previous item onto the head
            x1 = x[:-overlap_size]
            x1[:overlap_size] = overlap_win1 * overlap_mem + overlap_win2 * x1[:overlap_size]

            f.write(x1.astype(np.int16).tobytes())

            # bug fix: carry the trimmed tail *slice* into the next cross-fade;
            # the original `x1[-overlap_size]` (missing colon) produced a single
            # scalar sample, breaking the overlap-add between items
            overlap_mem = x[-overlap_size:]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    # resolve the entries of the filelist relative to --basedir
    filelist = read_filelist(args.basedir, args.filelist)

    concatenate(filelist, args.output, args.target_fs, normalize=args.normalize, db_min=args.db_min, db_max=args.db_max, verbose=args.verbose)
|
||||
@@ -0,0 +1,28 @@
|
||||
import argparse
|
||||
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from utils.layers.silk_upsampler import SilkUpsampler
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="input wave file")
|
||||
parser.add_argument("output", type=str, help="output wave file")
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    fs, x = wavfile.read(args.input)

    # being lazy for now: only 16 kHz, int16 input is supported
    assert fs == 16000 and x.dtype == np.int16

    # (batch, channels, samples) layout for the upsampler
    x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)

    upsampler = SilkUpsampler()
    y = upsampler(x)

    y = y.squeeze().numpy().astype(np.int16)

    # NOTE(review): the first 13 samples are dropped, presumably to compensate
    # the upsampler's delay -- confirm against SilkUpsampler's implementation
    wavfile.write(args.output, 48000, y[13:])
|
||||
@@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('commonvoice_base_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('--clips-per-language', required=False, type=int, default=10)
|
||||
parser.add_argument('--seed', required=False, type=int, default=2024)
|
||||
|
||||
|
||||
def select_clips(dir, num_clips=10):
    """Select a gender-balanced random sample of clips for one Common Voice language.

    Reads <dir>/validated.tsv and returns a list of num_clips//2 female and
    num_clips//2 male clip paths, drawing at most one clip per speaker.
    Returns None when the language has too little data.
    """

    # an odd num_clips cannot be split evenly between the two genders
    if num_clips % 2:
        print(f"warning: number of clips will be reduced to {num_clips - 1}")
    female = dict()
    male = dict()

    # NOTE(review): column 0 appears to be the client (speaker) id, column 1
    # the clip filename and column 8 the gender field -- confirm against the
    # Common Voice validated.tsv schema in use
    clips = np.genfromtxt(os.path.join(dir, 'validated.tsv'), delimiter='\t', dtype=str, invalid_raise=False)
    # NOTE(review): unused variable
    clips_by_client = dict()

    if len(clips.shape) < 2 or len(clips) < num_clips:
        # not enough data to proceed
        return None

    # group clips per speaker and gender; row 0 is the tsv header
    for client in set(clips[1:,0]):
        client_clips = clips[clips[:, 0] == client]
        f, m = False, False
        if 'female_feminine' in client_clips[:, 8]:
            female[client] = client_clips[client_clips[:, 8] == 'female_feminine']
            f = True
        if 'male_masculine' in client_clips[:, 8]:
            male[client] = client_clips[client_clips[:, 8] == 'male_masculine']
            m = True

        if f and m:
            # same speaker id tagged with both genders; kept in both pools
            print(f"both male and female clips under client {client}")


    if min(len(female), len(male)) < num_clips // 2:
        return None

    # select num_clips // 2 random female clients, then one clip per client
    female_client_selection = np.array(list(female.keys()), dtype=str)[np.random.choice(len(female), num_clips//2, replace=False)]
    female_clip_selection = []
    for c in female_client_selection:
        s_idx = np.random.randint(0, len(female[c]))
        female_clip_selection.append(os.path.join(dir, 'clips', female[c][s_idx, 1].item()))

    # select num_clips // 2 random male clients, then one clip per client
    male_client_selection = np.array(list(male.keys()), dtype=str)[np.random.choice(len(male), num_clips//2, replace=False)]
    male_clip_selection = []
    for c in male_client_selection:
        s_idx = np.random.randint(0, len(male[c]))
        male_clip_selection.append(os.path.join(dir, 'clips', male[c][s_idx, 1].item()))

    return female_clip_selection + male_clip_selection
|
||||
|
||||
def ffmpeg_available():
    """Return True if an ffmpeg binary can be invoked successfully."""
    try:
        proc = subprocess.run(['ffmpeg', '-h'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return proc.returncode == 0
    except OSError:
        # bug fix: the original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt; a missing or unrunnable binary raises OSError
        return False
|
||||
|
||||
|
||||
def convert_clips(selection, outdir):
    """Convert the selected clips to 16 kHz wav files via ffmpeg.

    selection maps language -> list of source clip paths. Returns the same
    mapping with the converted files' paths relative to outdir. Raises
    RuntimeError when ffmpeg is missing or a conversion fails.
    """
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")

    clipdir = os.path.join(outdir, 'clips')
    os.makedirs(clipdir, exist_ok=True)

    clipdict = dict()

    for lang, clips in selection.items():
        clipdict[lang] = []
        for clip in clips:
            # source basename without extension; ffmpeg infers formats
            clipname = os.path.splitext(os.path.split(clip)[-1])[0]
            target_name = os.path.join('clips', clipname + '.wav')
            call_args = ['ffmpeg', '-i', clip, '-ar', '16000', os.path.join(outdir, target_name)]
            print(call_args)
            r = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            if r.returncode != 0:
                raise RuntimeError(f'could not execute {call_args}')
            clipdict[lang].append(target_name)

    return clipdict
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # fail early: the conversion step below requires ffmpeg
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")

    args = parser.parse_args()

    base_dir = args.commonvoice_base_dir
    output_dir = args.output_dir
    seed = args.seed

    # make the random clip selection reproducible
    np.random.seed(seed)

    # every sub-folder of the base dir is treated as one language
    langs = os.listdir(base_dir)
    selection = dict()

    # NOTE(review): args.clips_per_language is parsed but never passed to
    # select_clips, which therefore always uses its default of 10
    for lang in langs:
        print(f"processing {lang}...")
        clips = select_clips(os.path.join(base_dir, lang))
        if clips is not None:
            selection[lang] = clips


    os.makedirs(output_dir, exist_ok=True)

    clips = convert_clips(selection, output_dir)

    # record the language -> converted clip mapping next to the clips
    with open(os.path.join(output_dir, 'clips.yml'), 'w') as f:
        yaml.dump(clips, f)
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash

# Generate speech-enhancement test data: for every wav under $INPUT, create
# one folder per bitrate containing the clean 16 kHz raw pcm and the
# Opus-coded ("noisy") version produced by the patched opus_demo.

INPUT="dataset/LibriSpeech"
OUTPUT="testdata"
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )


mkdir -p $OUTPUT

for fn in $(find $INPUT -name "*.wav")
do
    # item name without path or extension
    name=$(basename ${fn%*.wav})
    # convert to 16 kHz, 16-bit signed raw pcm once per item
    sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
    for br in ${BITRATES[@]}
    do
        folder=${OUTPUT}/"${name}_${br}.se"
        echo "creating ${folder}..."
        mkdir -p $folder
        cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
        # run in a subshell so opus_demo's side files land in the item folder
        (cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
    done
    rm -f ${OUTPUT}/tmp.raw
done
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
|
||||
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
|
||||
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
|
||||
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
|
||||
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
|
||||
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
from scipy.io import wavfile
|
||||
from pesq import pesq
|
||||
import numpy as np
|
||||
from moc import compare
|
||||
from moc2 import compare as compare2
|
||||
#from warpq import compute_WAPRQ as warpq
|
||||
from lace_loss_metric import compare as laceloss_compare
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='folder with processed items')
|
||||
parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
|
||||
|
||||
|
||||
def get_bitrates(folder):
    """Parse the whitespace-separated bitrates listed in <folder>/bitrates.txt."""
    with open(os.path.join(folder, 'bitrates.txt')) as f:
        content = f.read()

    return [int(token) for token in content.rstrip('\n').split()]
|
||||
|
||||
def get_itemlist(folder):
    """Return the first whitespace-separated token of each line in <folder>/items.txt."""
    with open(os.path.join(folder, 'items.txt')) as f:
        lines = f.readlines()

    return [line.split()[0] for line in lines]
|
||||
|
||||
|
||||
def process_item(folder, item, bitrate, metric):
    """Score one test item at one bitrate with the chosen metric.

    Returns [opus, lace, nolace] scores, each comparing the processed signal
    against the clean reference. Raises ValueError for unknown metrics.
    """
    # NOTE(review): fs is overwritten by each read -- assumes all four files
    # share the same sampling rate
    fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
    fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
    fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
    fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))

    # normalize int16 samples to [-1, 1) floats
    x_clean = x_clean.astype(np.float32) / 2**15
    x_opus = x_opus.astype(np.float32) / 2**15
    x_lace = x_lace.astype(np.float32) / 2**15
    x_nolace = x_nolace.astype(np.float32) / 2**15

    if metric == 'pesq':
        result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
    elif metric =='moc':
        result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
    elif metric =='moc2':
        result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
    # elif metric == 'warpq':
    #     result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
    elif metric == 'laceloss':
        result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
    else:
        raise ValueError(f'unknown metric {metric}')

    return result
|
||||
|
||||
def process_bitrate(folder, items, bitrate, metric):
    """Score every item at one bitrate; rows follow the order of `items`."""
    scores = np.zeros((len(items), 3))

    for row, item in enumerate(items):
        scores[row, :] = np.array(process_item(folder, item, bitrate, metric))

    return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    items = get_itemlist(args.folder)
    bitrates = get_bitrates(args.folder)

    # results[bitrate] is a (num_items, 3) array: columns opus/lace/nolace
    results = dict()
    for br in bitrates:
        print(f"processing bitrate {br}...")
        results[br] = process_bitrate(args.folder, items, br, args.metric)

    # stored as a pickled object array; load with allow_pickle=True
    np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)

    print("Done.")
|
||||
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
    """Resolve a torch window function by name and evaluate it.

    Args:
        win_name (str): one of 'bartlett_window', 'blackman_window',
            'hamming_window', 'hann_window', 'kaiser_window'.
        win_length (int): window length in samples.

    Returns:
        Tensor: the window of length win_length.

    Raises:
        ValueError: if win_name is not a known window type.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window'  : torch.hamming_window,
        'hann_window'     : torch.hann_window,
        'kaiser_window'   : torch.kaiser_window
    }

    if win_name not in window_dict:
        # bug fix: the original raised a bare ValueError() with no message
        raise ValueError(f"unknown window type: {win_name}")

    return window_dict[win_name](win_length, *args, **kwargs)
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    win = get_window(window, win_length).to(x.device)
    spectrum = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)

    # floor the magnitude so downstream logs and divisions stay finite
    return torch.clamp(torch.abs(spectrum), min=1e-7)
|
||||
|
||||
def spectral_convergence_loss(Y_true, Y_pred):
    """Mean relative Frobenius error between the magnitude spectra.

    NOTE(review): the denominator normalizes by Y_pred rather than the
    reference Y_true -- kept as is to preserve behavior.
    """
    reduce_dims = list(range(1, len(Y_pred.shape)))
    numerator = torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=reduce_dims)
    denominator = torch.norm(Y_pred, p="fro", dim=reduce_dims) + 1e-6
    return torch.mean(numerator / denominator)
|
||||
|
||||
|
||||
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference of the log-magnitude spectra."""
    log_true = torch.log(torch.abs(Y_true) + 1e-15)
    log_pred = torch.log(torch.abs(Y_pred) + 1e-15)
    return torch.abs(log_true - log_pred).mean()
|
||||
|
||||
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the mean normalized cross-correlation of the magnitude spectra."""
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, len(mag_pred.shape)))
    numerator = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    denominator = torch.sqrt(torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9)
    return 1 - (numerator / denominator).mean()
|
||||
|
||||
|
||||
|
||||
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel-spectrogram loss.

    Averages the mean absolute log-magnitude distance over mel spectrograms
    computed at several FFT sizes; FFT sizes below 128 use half the mel bands.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):

        # plain (non-module) attributes, safe to set before super().__init__()
        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels

        super().__init__()

        self.mel_specs = []
        for fft_size in fft_sizes:
            # hop derived from the requested overlap factor
            hop_size = int(round(fft_size * (1 - self.overlap)))

            n_mels = self.n_mels
            if fft_size < 128:
                # fewer mel bands at very low frequency resolution
                n_mels //= 2

            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))

        # register each transform as a submodule so .to(device) moves them too
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)

    def forward(self, y_true, y_pred):
        """Return the mean log-mel distance between y_true and y_pred."""

        loss = torch.zeros(1, device=y_true.device)

        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)

        # average across resolutions
        loss = loss / len(self.mel_specs)

        return loss
|
||||
|
||||
def create_weight_matrix(num_bins, bins_per_band=10):
    """Create a (num_bins, num_bins) band-averaging matrix.

    Row i averages roughly bins_per_band bins centered on bin i; weight that
    would fall outside the valid range is folded back onto the edge bins, and
    the result is normalized by bins_per_band.
    """
    m = torch.zeros((num_bins, num_bins), dtype=torch.float32)

    # band spans r0 bins below and r1 bins above the center
    r0 = bins_per_band // 2
    r1 = bins_per_band - r0

    for i in range(num_bins):
        i0 = max(i - r0, 0)
        j0 = min(i + r1, num_bins)

        m[i, i0: j0] += 1

        # fold truncated weight back at the lower edge
        # NOTE(review): the folded index ranges here and below look asymmetric;
        # confirm the intended boundary handling
        if i < r0:
            m[i, :r0 - i] += 1

        # fold truncated weight back at the upper edge
        if i > num_bins - r1:
            m[i, num_bins - r1 - i:] += 1

    return m / bins_per_band
|
||||
|
||||
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Spectral convergence weighted by a per-bin spectral-flatness measure.

    w is a band-averaging matrix (see create_weight_matrix); bins in tonal
    regions (low spectral flatness) receive weight close to 1, noise-like
    bins close to 0.
    """

    # calculate sfm based weights
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)

    # band-smoothed log-magnitude and magnitude (geometric vs arithmetic mean)
    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)

    # spectral flatness: exp(mean log) / mean, in (0, 1]
    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)

    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)

    # weighted L1 convergence, normalized per batch element
    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )

    return loss
|
||||
|
||||
def gen_filterbank(N, Fs=16000):
    """Build an N x (N+1) ERB-spaced smoothing filterbank; rows sum to one."""
    in_freq = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    out_freq = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]
    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb = 24.7 + .108 * in_freq
    delta = np.abs(in_freq - out_freq) / erb
    # quadratic rolloff inside half an ERB, linear (in dB) outside
    inside = (delta < .5).astype('float32')
    response_db = -12 * inside * delta ** 2 + (1 - inside) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)
    # normalize each output band to unit gain
    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
|
||||
|
||||
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-distance between filterbank-smoothed magnitude spectra."""
    smoothed_true = torch.matmul(filterbank, torch.abs(Y_true))
    smoothed_pred = torch.matmul(filterbank, torch.abs(Y_pred))

    diff = torch.log(smoothed_true + 1e-9) - torch.log(smoothed_pred + 1e-9)
    return torch.abs(diff).mean()
|
||||
|
||||
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: a weighted sum of spectral distances.

    The individual terms (log magnitude, spectral convergence, SFM-weighted
    spectral convergence, filterbank-smoothed log magnitude, spectral
    cross-correlation) are accumulated over all FFT sizes, weighted, and
    averaged across resolutions. Terms with zero weight are skipped entirely.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=0,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=2,
                 sxcorr_weight=1):
        super().__init__()

        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs

        # weights for SFM weighted spectral convergence loss
        # (non-trainable parameters so they follow the module across devices)
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width of ~1 kHz in bins, rounded to an even number, capped at 11+1
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )

        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )


    def __call__(self, y_true, y_pred):
        """Return the combined multi-resolution loss for a pair of signals."""

        # per-term accumulators, summed over fft sizes and averaged at the end
        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)

        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size

            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)

            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)

            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)

            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])

            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])

            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)


        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)

        return total_loss
|
||||
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Mean squared time-domain error, normalized per item.

    NOTE(review): normalization uses the RMS of y_pred, not the reference.
    """
    reduce_dims = list(range(1, len(y_true.shape)))
    err = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    scale = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5 + 1e-6
    return torch.mean(err / scale)
|
||||
|
||||
|
||||
class LaceLoss(nn.Module):
    """Combined spectral + time-domain loss used as the LACE quality metric."""
    def __init__(self):
        super().__init__()

        # only smoothed log magnitude and spectral cross-correlation terms
        self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)


    def forward(self, x, y):
        """Weighted sum of the spectral loss and the relative time-domain error."""
        specloss = self.stftloss(x, y)
        phaseloss = td_l2_norm(x, y)
        # normalized by the sum of loss weights: (2 + 1) spectral + 10 temporal
        total_loss = (specloss + 10 * phaseloss) / 13

        return total_loss

    def compare(self, x_ref, x_deg):
        """Distance between reference and degraded numpy signals (10x loss)."""
        # trim items to same size
        n = min(len(x_ref), len(x_deg))
        x_ref = x_ref[:n].copy()
        x_deg = x_deg[:n].copy()

        # pre-emphasis
        x_ref[1:] -= 0.85 * x_ref[:-1]
        x_deg[1:] -= 0.85 * x_deg[:-1]

        # run on whatever device the module lives on
        device = next(iter(self.parameters())).device

        x = torch.from_numpy(x_ref).to(device)
        y = torch.from_numpy(x_deg).to(device)

        with torch.no_grad():
            dist = 10 * self.forward(x, y)

        return dist.cpu().numpy().item()
|
||||
|
||||
|
||||
# module-level singleton so compare() can be used as a plain metric function
lace_loss = LaceLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lace_loss.to(device)

def compare(x, y):
    """Compare degraded signal y against reference x using the LACE loss metric."""
    return lace_loss.compare(x, y)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
|
||||
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
|
||||
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Load any precomputed results_<metric>.npy files found in `folder`.

    Returns a dict mapping metric name -> results dict (bitrate -> score
    array), containing only the metrics whose file exists.
    """
    data = dict()

    # idiom fix: replace six copy-pasted isfile/np.load branches with one loop;
    # metric order matches the original insertion order
    for metric in ('moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            # files hold pickled dicts, hence allow_pickle + .item()
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def plot_data(filename, data, title=None):
    """Box-plot per-bitrate Opus/LACE/NoLACE scores and save to `filename`.

    data maps bitrate (bit/s) -> (num_items, 3) score array whose columns
    are Opus, LACE and NoLACE.
    """
    # one labelled column per (bitrate, system) pair, in bitrate order
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]

    # NOTE(review): text.usetex requires a working LaTeX installation
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })

    black = '#000000'
    red = '#ff5745'
    blue = '#007dbc'
    colors = [black, red, blue]
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='LACE'),
                       Patch(facecolor=colors[2], label='NoLACE')]

    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)

    # cycle the three system colors across the bitrate groups
    for i, patch in enumerate(bplot['boxes']):
        patch.set_facecolor(colors[i%3])

    ax.set_xticklabels(compare_dict.keys(), rotation=290)

    if title is not None:
        ax.set_title(title)

    ax.legend(handles=legend_elements)

    fig.savefig(filename, bbox_inches='tight')
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)


    # 'all' plots every metric for which precomputed results were found
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())

    print("Done.")
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command-line interface: positional results folder plus metric/output options.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Collect pre-computed metric results stored in *folder*.

    Each metric is expected in a file named results_<metric>.npy holding a
    pickled dict that maps bitrate (bps) -> ndarray of per-item scores.
    Missing files are silently skipped.

    Args:
        folder: directory containing the results_*.npy files.

    Returns:
        dict: metric name -> loaded results dict (empty if nothing found).
    """
    data = dict()

    # Consolidates the previous five copy-pasted per-metric blocks; tuple
    # order matches the original so --metric all keeps the same ordering.
    for metric in ('moc', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def plot_data(filename, data, title=None):
    """Save a grouped box plot comparing Opus and two LACE loss variants.

    Args:
        filename: output image path handed to ``fig.savefig``.
        data: dict mapping bitrate (bps) -> ndarray of shape (items, 3),
            columns ordered as (Opus, LACE MOC-only, LACE MOC+TD).
        title: optional plot title.
    """
    # One labelled column per (bitrate, system) pair, grouped by bitrate.
    series = dict()
    for bitrate in data:
        scores = data[bitrate]
        for column, system in enumerate(('Opus', 'LACE (MOC only)', 'LACE (MOC + TD)')):
            series[f'{system} {bitrate/1000:.1f} kb/s'] = scores[:, column]

    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })

    colors = ['pink', 'lightblue', 'lightgreen']
    legend_elements = [
        Patch(facecolor=colors[0], label='Opus SILK'),
        Patch(facecolor=colors[1], label='MOC loss only'),
        Patch(facecolor=colors[2], label='MOC + TD loss'),
    ]

    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    boxes = ax.boxplot(series.values(), showfliers=False, notch=True, patch_artist=True)

    # Boxes repeat in groups of three, one per system.
    for index, box in enumerate(boxes['boxes']):
        box.set_facecolor(colors[index % 3])

    ax.set_xticklabels(series.keys(), rotation=290)

    if title is not None:
        ax.set_title(title)

    ax.legend(handles=legend_elements)
    fig.savefig(filename, bbox_inches='tight')
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    # --metric all expands to every metric present on disk.
    if args.metric == 'all':
        metrics = list(data.keys())
    else:
        metrics = [args.metric]

    # Plots land in the input folder unless --output overrides it.
    if args.output is None:
        folder = args.folder
    else:
        folder = args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        target = os.path.join(folder, f"boxplot_{metric}.png")
        plot_data(target, data[metric], title=metric.upper())

    print("Done.")
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command-line interface. load_data below also picks up results_moc2.npy,
# so 'moc2' is offered as an individual choice (previously reachable only
# implicitly via 'all') — backward-compatible addition.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'moc2', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Collect pre-computed metric results stored in *folder*.

    Each metric is expected in a file named results_<metric>.npy holding a
    pickled dict that maps bitrate (bps) -> ndarray of per-item scores.
    Missing files are silently skipped.

    Args:
        folder: directory containing the results_*.npy files.

    Returns:
        dict: metric name -> loaded results dict (empty if nothing found).
    """
    data = dict()

    # Consolidates the previous six copy-pasted per-metric blocks; tuple
    # order matches the original so --metric all keeps the same ordering.
    for metric in ('moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def make_table(filename, data, title=None):
    """Write mean (std) score tables to <filename>.txt/.html/.csv and print them.

    Args:
        filename: output path prefix; the format extension is appended.
        data: dict mapping bitrate (bps) -> ndarray of shape (items, 3)
            with columns (Opus, LACE, NoLACE).
        title: unused; kept for interface compatibility.
    """
    table = PrettyTable()
    table.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']

    def fmt(scores):
        # "mean (std)" summary of one system's scores at one bitrate
        return f"{float(scores.mean()):.3f} ({float(scores.std()):.2f})"

    for bitrate, scores in data.items():
        table.add_row([bitrate, fmt(scores[:, 0]), fmt(scores[:, 1]), fmt(scores[:, 2])])

    # Same table rendered in all three formats.
    for extension, rendered in ((".txt", str(table)),
                                (".html", table.get_html_string()),
                                (".csv", table.get_csv_string())):
        with open(filename + extension, "w") as sink:
            sink.write(rendered)

    print(table)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
    """Write tables of per-item improvement over Opus to .txt/.html/.csv.

    Args:
        filename: output path prefix; the format extension is appended.
        data: dict mapping bitrate (bps) -> ndarray of shape (items, 3)
            with columns (Opus, LACE, NoLACE).
        title: unused; kept for interface compatibility.
    """
    table = PrettyTable()
    table.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']

    def fmt(delta):
        # "mean (std)" summary of score differences at one bitrate
        return f"{float(delta.mean()):.3f} ({float(delta.std()):.2f})"

    for bitrate, scores in data.items():
        baseline = scores[:, 0]
        table.add_row([bitrate, fmt(scores[:, 1] - baseline), fmt(scores[:, 2] - baseline)])

    # Same table rendered in all three formats.
    for extension, rendered in ((".txt", str(table)),
                                (".html", table.get_html_string()),
                                (".csv", table.get_csv_string())):
        with open(filename + extension, "w") as sink:
            sink.write(rendered)

    print(table)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    # --metric all expands to every metric present on disk.
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    # Tables land in the input folder unless --output overrides it.
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        # was "Plotting data for ..." — this script generates tables, not plots
        print(f"Generating tables for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
# Command-line interface: positional results folder plus metric/output options.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
    """Collect pre-computed metric results stored in *folder*.

    Each metric is expected in a file named results_<metric>.npy holding a
    pickled dict that maps bitrate (bps) -> ndarray of per-item scores.
    Missing files are silently skipped.

    Args:
        folder: directory containing the results_*.npy files.

    Returns:
        dict: metric name -> loaded results dict (empty if nothing found).
    """
    data = dict()

    # Consolidates the previous five copy-pasted per-metric blocks; tuple
    # order matches the original so --metric all keeps the same ordering.
    for metric in ('moc', 'pesq', 'warpq', 'nomad', 'laceloss'):
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()

    return data
|
||||
|
||||
def make_table(filename, data, title=None):
    """Write mean (std) score tables to <filename>.txt/.html/.csv and print them.

    Args:
        filename: output path prefix; the format extension is appended.
        data: dict mapping bitrate (bps) -> ndarray of shape (items, 3)
            with columns (Opus, LACE, NoLACE).
        title: unused; kept for interface compatibility.
    """
    table = PrettyTable()
    table.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']

    def fmt(scores):
        # "mean (std)" summary of one system's scores at one bitrate
        return f"{float(scores.mean()):.3f} ({float(scores.std()):.2f})"

    for bitrate, scores in data.items():
        table.add_row([bitrate, fmt(scores[:, 0]), fmt(scores[:, 1]), fmt(scores[:, 2])])

    # Same table rendered in all three formats.
    for extension, rendered in ((".txt", str(table)),
                                (".html", table.get_html_string()),
                                (".csv", table.get_csv_string())):
        with open(filename + extension, "w") as sink:
            sink.write(rendered)

    print(table)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
    """Write tables of per-item improvement over Opus to .txt/.html/.csv.

    Args:
        filename: output path prefix; the format extension is appended.
        data: dict mapping bitrate (bps) -> ndarray of shape (items, 3)
            with columns (Opus, LACE, NoLACE).
        title: unused; kept for interface compatibility.
    """
    table = PrettyTable()
    table.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']

    def fmt(delta):
        # "mean (std)" summary of score differences at one bitrate
        return f"{float(delta.mean()):.3f} ({float(delta.std()):.2f})"

    for bitrate, scores in data.items():
        baseline = scores[:, 0]
        table.add_row([bitrate, fmt(scores[:, 1] - baseline), fmt(scores[:, 2] - baseline)])

    # Same table rendered in all three formats.
    for extension, rendered in ((".txt", str(table)),
                                (".html", table.get_html_string()),
                                (".csv", table.get_csv_string())):
        with open(filename + extension, "w") as sink:
            sink.write(rendered)

    print(table)
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)

    # --metric all expands to every metric present on disk.
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    # Tables land in the input folder unless --output overrides it.
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)

    for metric in metrics:
        # was "Plotting data for ..." — this script generates tables, not plots
        print(f"Generating tables for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])

    print("Done.")
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
    """Energy-based voice activity detection.

    Args:
        x (np.ndarray): mono input signal.
        fs (int): sampling rate in Hz; 20 ms frames are derived from it.
        stop_db (float): frames whose smoothed energy falls this many dB
            below the peak frame energy are marked inactive.
            NOTE(review): threshold uses 10**(stop_db/20), i.e. an
            amplitude-style scale applied to energies — confirm intent.

    Returns:
        tuple: (x trimmed to a whole number of frames, soft activity mask
        of the same length with values in [0, 1]).
    """
    # 20 ms frames, rounded up to a whole number of samples
    frame_length = (fs + 49) // 50
    num_frames = len(x) // frame_length
    x = x[: num_frames * frame_length]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    # 5-frame moving average smooths out short energy dips
    smoothed_energy = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    threshold = frame_energy.max() * 10 ** (stop_db / 20)
    activity = np.ones_like(frames)
    activity[smoothed_energy < threshold, :] = 0
    activity = activity.reshape(-1)

    # normalized sine taper softens the on/off transitions of the mask
    taper = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    taper = taper / taper.sum()

    mask = np.convolve(activity, taper, mode='same')

    return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample VAD mask onto an analysis-frame grid.

    Args:
        mask (np.ndarray): per-sample activity mask.
        num_frames (int): number of output frames.
        frame_size (int): samples covered by one analysis frame.
        hop_size (int): samples between frame starts.

    Returns:
        np.ndarray: per-frame mask; frame i is the mean of the samples its
        window covers (mask is zero-padded or truncated to fit the grid).
    """
    required = frame_size + (num_frames - 1) * hop_size
    shortfall = required - len(mask)
    if shortfall > 0:
        mask = np.concatenate((mask, np.zeros(shortfall)), dtype=mask.dtype)
    else:
        mask = mask[:required]

    framed = [mask[start : start + frame_size].mean()
              for start in range(0, num_frames * hop_size, hop_size)]
    return np.array(framed)
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Short-time power spectrum of a 1-D signal.

    Args:
        x (np.ndarray): input signal.
        window_size (int): analysis window length in samples.
        hop_size (int): hop between consecutive windows.
        window (str): window name understood by scipy.signal.get_window.

    Returns:
        np.ndarray: array of shape (num_frames, window_size // 2 + 1) with
        squared magnitudes of the non-negative FFT bins.
    """
    num_frames = (len(x) - window_size - hop_size) // hop_size
    taper = scipy.signal.get_window(window, window_size)
    num_bins = window_size // 2 + 1

    stacked = np.concatenate(
        [x[np.newaxis, k * hop_size : k * hop_size + window_size]
         for k in range(num_frames)]
    ) * taper
    spectra = np.fft.fft(stacked, axis=1)[:, :num_bins]

    return np.abs(spectra) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """Band-to-band spreading matrix for simultaneous (frequency) masking.

    Args:
        num_bands (int): number of frequency bands.
        up_factor (float): per-band geometric decay of upward spreading.
        down_factor (float): per-band geometric decay of downward spreading.

    Returns:
        np.ndarray: (num_bands, num_bands) matrix combining both spreading
        directions (downward applied after upward).
    """
    upward = np.zeros((num_bands, num_bands))
    downward = np.zeros((num_bands, num_bands))

    for band in range(num_bands):
        # geometric decay towards lower bands ...
        upward[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        # ... and towards higher bands
        downward[band, band:] = down_factor ** np.arange(num_bands - band)

    return downward @ upward
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank built from band edge indices.

    Args:
        band_limits (list): bin indices delimiting the bands; band i spans
            bins [band_limits[i], band_limits[i+1]).
        num_bins (int): total number of bins; defaults to the last edge.

    Returns:
        np.ndarray: (num_bands, num_bins) 0/1 band-membership matrix.
    """
    if num_bins is None:
        num_bins = band_limits[-1]
    num_bands = len(band_limits) - 1

    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
|
||||
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames classified as silence are
            excluded from the error average

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim to common length and rescale to 16-bit sample range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    # small additive floor keeps the log below well-defined for silent bins
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies of the reference
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking: spread reference band energies across bands
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: first-order decay of the mask over frames
    for frame in range(1, num_frames):
        mask_x[frame, :] += 0.5 * mask_x[frame - 1, :]

    # add a fraction of the (band-expanded) mask to both spectra
    mask_bins = 0.1 * (mask_x @ fb)
    masked_psd_x = psd_x + mask_bins
    masked_psd_y = psd_y + mask_bins

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band/frame
    ratio = masked_psd_y / masked_psd_x
    log_err = np.log(ratio) ** 2
    Eb = (log_err @ fb.T) / np.sum(fb, axis=1)
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # L3 mean over active frames, then a square-root compression
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1 / 6)

    return float(err)
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both files must be 16 kHz; the previous max(fs1, fs2) check let an
    # 8 kHz file slip through whenever the other one was 16 kHz
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # int16 PCM -> float in [-1, 1)
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
    """Energy-based voice activity detection.

    Args:
        x (np.ndarray): mono input signal.
        fs (int): sampling rate in Hz; 20 ms frames are derived from it.
        stop_db (float): frames whose smoothed energy falls this many dB
            below the peak frame energy are marked inactive.
            NOTE(review): threshold uses 10**(stop_db/20), i.e. an
            amplitude-style scale applied to energies — confirm intent.

    Returns:
        tuple: (x trimmed to a whole number of frames, soft activity mask
        of the same length with values in [0, 1]).
    """
    # 20 ms frames, rounded up to a whole number of samples
    frame_length = (fs + 49) // 50
    num_frames = len(x) // frame_length
    x = x[: num_frames * frame_length]

    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    # 5-frame moving average smooths out short energy dips
    smoothed_energy = np.convolve(frame_energy, np.ones(5) / 5, mode='same')

    threshold = frame_energy.max() * 10 ** (stop_db / 20)
    activity = np.ones_like(frames)
    activity[smoothed_energy < threshold, :] = 0
    activity = activity.reshape(-1)

    # normalized sine taper softens the on/off transitions of the mask
    taper = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    taper = taper / taper.sum()

    mask = np.convolve(activity, taper, mode='same')

    return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample VAD mask onto an analysis-frame grid.

    Args:
        mask (np.ndarray): per-sample activity mask.
        num_frames (int): number of output frames.
        frame_size (int): samples covered by one analysis frame.
        hop_size (int): samples between frame starts.

    Returns:
        np.ndarray: per-frame mask; frame i is the mean of the samples its
        window covers (mask is zero-padded or truncated to fit the grid).
    """
    required = frame_size + (num_frames - 1) * hop_size
    shortfall = required - len(mask)
    if shortfall > 0:
        mask = np.concatenate((mask, np.zeros(shortfall)), dtype=mask.dtype)
    else:
        mask = mask[:required]

    framed = [mask[start : start + frame_size].mean()
              for start in range(0, num_frames * hop_size, hop_size)]
    return np.array(framed)
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Short-time power spectrum of a 1-D signal.

    Args:
        x (np.ndarray): input signal.
        window_size (int): analysis window length in samples.
        hop_size (int): hop between consecutive windows.
        window (str): window name understood by scipy.signal.get_window.

    Returns:
        np.ndarray: array of shape (num_frames, window_size // 2 + 1) with
        squared magnitudes of the non-negative FFT bins.
    """
    num_frames = (len(x) - window_size - hop_size) // hop_size
    taper = scipy.signal.get_window(window, window_size)
    num_bins = window_size // 2 + 1

    stacked = np.concatenate(
        [x[np.newaxis, k * hop_size : k * hop_size + window_size]
         for k in range(num_frames)]
    ) * taper
    spectra = np.fft.fft(stacked, axis=1)[:, :num_bins]

    return np.abs(spectra) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """Band-to-band spreading matrix for simultaneous (frequency) masking.

    Args:
        num_bands (int): number of frequency bands.
        up_factor (float): per-band geometric decay of upward spreading.
        down_factor (float): per-band geometric decay of downward spreading.

    Returns:
        np.ndarray: (num_bands, num_bands) matrix combining both spreading
        directions (downward applied after upward).
    """
    upward = np.zeros((num_bands, num_bands))
    downward = np.zeros((num_bands, num_bands))

    for band in range(num_bands):
        # geometric decay towards lower bands ...
        upward[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        # ... and towards higher bands
        downward[band, band:] = down_factor ** np.arange(num_bands - band)

    return downward @ upward
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank built from band edge indices.

    Args:
        band_limits (list): bin indices delimiting the bands; band i spans
            bins [band_limits[i], band_limits[i+1]).
        num_bins (int): total number of bins; defaults to the last edge.

    Returns:
        np.ndarray: (num_bands, num_bins) 0/1 band-membership matrix.
    """
    if num_bins is None:
        num_bins = band_limits[-1]
    num_bands = len(band_limits) - 1

    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1

    return fb
|
||||
|
||||
|
||||
def _compare(x, y, apply_vad=False, factor=1):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, frames classified as silence are
            excluded from the error average
        factor (int): resolution multiplier applied to window, hop and
            band-edge sizes

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
    window_size = factor * 160
    hop_size = factor * 40
    num_bins = window_size // 2 + 1
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=num_bins)

    # trim to common length and rescale to 16-bit sample range
    num_samples = min(len(x), len(y))
    x = x[:num_samples].copy() * 2**15
    y = y[:num_samples].copy() * 2**15

    # small additive floor keeps the log below well-defined for silent bins
    psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
    psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000

    num_frames = psd_x.shape[0]

    # average band energies of the reference
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking: spread reference band energies across bands
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: first-order decay, weakened for finer time steps
    for frame in range(1, num_frames):
        mask_x[frame, :] += (0.5 ** factor) * mask_x[frame - 1, :]

    # add a fraction of the (band-expanded) mask to both spectra
    mask_bins = 0.1 * (mask_x @ fb)
    masked_psd_x = psd_x + mask_bins
    masked_psd_y = psd_y + mask_bins

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band/frame
    ratio = masked_psd_y / masked_psd_x
    log_err = np.log(ratio) ** 2
    Eb = (log_err @ fb.T) / np.sum(fb, axis=1)
    Ef = np.mean(Eb, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # L3 mean over active frames, then a square-root compression
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1 / 6)

    return float(err)
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """Perceptual error between *x* and *y* (see ``_compare``).

    Evaluates ``_compare`` at a single resolution (factor=1) and combines
    the partial errors with an L2 norm, leaving room for adding further
    resolutions later.
    """
    partial_errors = [_compare(x, y, apply_vad=apply_vad, factor=1)]
    return np.linalg.norm(partial_errors, ord=2)
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # both files must be 16 kHz; the previous max(fs1, fs2) check let an
    # 8 kHz file slip through whenever the other one was 16 kHz
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # int16 PCM -> float in [-1, 1)
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
||||
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash

# Batch-process a folder of wav files: encode with opus_demo at several
# bitrates and enhance with LACE and NoLACE. Requires the environment
# variables PYTHON, TESTMODEL, OPUSDEMO, LACE and NOLACE (checked below).

if [ ! -f "$PYTHON" ]
then
    echo "PYTHON variable does not link to a file. Please point it to your python executable."
    exit 1
fi

if [ ! -f "$TESTMODEL" ]
then
    echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
    exit 1
fi

if [ ! -f "$OPUSDEMO" ]
then
    echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
    exit 1
fi

if [ ! -f "$LACE" ]
then
    echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
    exit 1
fi

if [ ! -f "$NOLACE" ]
then
    # fixed: this message previously said "LACE variable"
    echo "NOLACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
    exit 1
fi

case $# in
    2) INPUT=$1; OUTPUT=$2;;
    *) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
esac

# refuse to clobber an existing output folder
if [ -d "$OUTPUT" ]
then
    echo "output folder $OUTPUT exists, aborting..."
    exit 1
fi

mkdir -p "$OUTPUT"

if [ "$BITRATES" == "" ]
then
    BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
    echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
fi

# record which checkpoints produced this dataset
echo "LACE=${LACE}" > ${OUTPUT}/info.txt
echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt

ITEMFILE=${OUTPUT}/items.txt
BITRATEFILE=${OUTPUT}/bitrates.txt

FPROCESSING=${OUTPUT}/processing
FCLEAN=${OUTPUT}/clean
FOPUS=${OUTPUT}/opus
FLACE=${OUTPUT}/lace
FNOLACE=${OUTPUT}/nolace

mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE

echo "${BITRATES[@]}" > $BITRATEFILE

for fn in $(find $INPUT -type f -name "*.wav")
do
    # random name so items from different source subfolders cannot collide
    UUID=$(uuid)
    echo "$UUID $fn" >> $ITEMFILE
    PIDS=( )
    for br in ${BITRATES[@]}
    do
        # run opus
        pfolder=${FPROCESSING}/${UUID}_${br}
        mkdir -p $pfolder
        sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16
        (cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)

        # copy clean and opus
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav

        # run LACE (in the background; collected by the wait loop below)
        $PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav &
        PIDS+=( "$!" )

        # run NoLACE
        $PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav &
        PIDS+=( "$!" )
    done
    # wait for all enhancement jobs of this item before starting the next
    for pid in ${PIDS[@]}
    do
        wait $pid
    done
done
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
from scipy.spatial.distance import cdist
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
|
||||
from nomad_audio.nomad import Nomad
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='folder with processed items')
|
||||
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
|
||||
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
|
||||
|
||||
|
||||
def get_bitrates(folder):
    """Read the list of bitrates stored in <folder>/bitrates.txt.

    The file contains whitespace-separated integers on a single line.
    Returns them as a list of ints.
    """
    path = os.path.join(folder, 'bitrates.txt')
    with open(path) as f:
        content = f.read()

    return [int(token) for token in content.rstrip('\n').split()]
|
||||
|
||||
def get_itemlist(folder):
    """Return the item UUIDs (first column) listed in <folder>/items.txt."""
    path = os.path.join(folder, 'items.txt')
    with open(path) as f:
        lines = f.readlines()

    # each line is "<uuid> <original path>"; keep only the uuid
    return [line.split()[0] for line in lines]
|
||||
|
||||
|
||||
def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
    """Score the files in deg_folder against ref_folder with NOMAD.

    Returns (results, ref_embeddings) where results maps item name -> score.
    In non-full-reference mode the reference embeddings are not computed and
    None is returned in their place; in full-reference mode the (possibly
    precomputed) reference embeddings are returned so they can be reused.
    """
    model = Nomad(device=device)

    if not full_reference:
        # no-reference style prediction directly from the two folders
        scores = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
        return scores, None

    def embed(folder):
        # one-column dataframe of absolute paths, embedded by the model
        frame = pd.DataFrame(sorted(os.listdir(folder)))
        frame.columns = ['filename']
        frame['filename'] = [os.path.join(folder, name) for name in frame['filename']]
        return model.get_embeddings_csv(model.model, frame).set_index('filename')

    if ref_embeddings is None:
        print(f"Computing reference embeddings from {ref_folder}")
        ref_embeddings = embed(ref_folder)

    print(f"Computing degraded embeddings from {deg_folder}")
    deg_embeddings = embed(deg_folder)

    # full pairwise distance matrix; only the diagonal (matched items) is
    # used — wasteful, but simple
    distances = np.diag(cdist(ref_embeddings, deg_embeddings))
    names = [path.split('/')[-1].split('.')[0] for path in deg_embeddings.index]

    return dict(zip(names, distances)), ref_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
def nomad_process_all(folder, full_reference=False, device=None):
    """Run NOMAD on the opus/lace/nolace conditions of a processed dataset.

    Args:
        folder: dataset folder containing bitrates.txt, items.txt and the
            per-condition wav subfolders (clean/opus/lace/nolace).
        full_reference: if True, use NOMAD as a full-reference metric.
        device: compute device passed through to Nomad.

    Returns:
        {bitrate: ndarray of shape (num_items, 3)} with score columns
        (opus, lace, nolace).

    Fix vs. original: the `device` argument was accepted but never forwarded
    to nomad_wrapper; it is now passed through.
    """
    bitrates = get_bitrates(folder)
    items = get_itemlist(folder)
    conditions = ['clean', 'opus', 'lace', 'nolace']

    with tempfile.TemporaryDirectory() as tmp:
        subdirs = {cond: os.path.join(tmp, cond) for cond in conditions}
        for d in subdirs.values():
            os.makedirs(d)

        # Stage per-condition copies named <item>_<bitrate>.wav so NOMAD can
        # match reference and degraded files by name.
        for br in bitrates:
            for item in items:
                for cond in conditions:
                    src = os.path.join(folder, cond, f"{item}_{br}_{cond}.wav")
                    dst = os.path.join(subdirs[cond], f"{item}_{br}.wav")
                    shutil.copyfile(src, dst)

        # Reference embeddings are computed once (first call) and reused.
        nomad_opus, ref_embeddings = nomad_wrapper(subdirs['clean'], subdirs['opus'], full_reference=full_reference, ref_embeddings=None, device=device)
        nomad_lace, ref_embeddings = nomad_wrapper(subdirs['clean'], subdirs['lace'], full_reference=full_reference, ref_embeddings=ref_embeddings, device=device)
        nomad_nolace, ref_embeddings = nomad_wrapper(subdirs['clean'], subdirs['nolace'], full_reference=full_reference, ref_embeddings=ref_embeddings, device=device)

    # Collect the per-item scores into one table per bitrate.
    results = dict()
    for br in bitrates:
        table = np.zeros((len(items), 3))
        for row, item in enumerate(items):
            key = f"{item}_{br}"
            table[row, 0] = nomad_opus[key]
            table[row, 1] = nomad_lace[key]
            table[row, 2] = nomad_nolace[key]
        results[br] = table

    return results
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    # parsed here only to fail early if the folder layout is wrong;
    # nomad_process_all re-reads both lists itself
    items = get_itemlist(args.folder)
    bitrates = get_bitrates(args.folder)

    results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)

    np.save(os.path.join(args.folder, 'results_nomad.npy'), results)

    print("Done.")
|
||||
@@ -0,0 +1,196 @@
|
||||
import os
|
||||
import argparse
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
|
||||
from moc2 import compare as moc
|
||||
|
||||
DEBUG=False
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('inputdir', type=str, help='Input folder with test items')
|
||||
parser.add_argument('outputdir', type=str, help='Output folder')
|
||||
parser.add_argument('bitrate', type=int, help='bitrate to test')
|
||||
parser.add_argument('--reference_opus_demo', type=str, default='./opus_demo', help='reference opus_demo binary for generating bitstreams and reference output')
|
||||
parser.add_argument('--encoder_options', type=str, default="", help='encoder options (e.g. -complexity 5)')
|
||||
parser.add_argument('--test_opus_demo', type=str, default='./opus_demo', help='opus_demo binary under test')
|
||||
parser.add_argument('--test_opus_demo_options', type=str, default='-dec_complexity 7', help='options for test opus_demo (e.g. "-dec_complexity 7")')
|
||||
parser.add_argument('--verbose', type=int, default=0, help='verbosity level: 0 for quiet (default), 1 for reporting individual test results, 2 for reporting per-item scores in failed tests')
|
||||
|
||||
def run_opus_encoder(opus_demo_path, input_pcm_path, bitstream_path, application, fs, num_channels, bitrate, options=None, verbose=False):
    """Run opus_demo in encode mode.

    Args:
        opus_demo_path: path to the opus_demo binary.
        input_pcm_path: input raw 16-bit pcm file.
        bitstream_path: output bitstream file.
        application: opus application mode (e.g. "voip").
        fs: sampling rate in Hz.
        num_channels: number of channels.
        bitrate: target bitrate in bits/s.
        options: extra command-line options inserted before the file arguments.
        verbose: if True, print the command line and let process output through.

    Returns:
        0 on success, a non-zero value on failure.

    Fixes vs. original: mutable default argument `options=[]` replaced with
    None; bare `except:` narrowed to OSError; the encoder's exit status is
    propagated instead of always returning 0 once the process spawned.
    """
    options = [] if options is None else list(options)

    call_args = [
        opus_demo_path,
        "-e",
        application,
        str(fs),
        str(num_channels),
        str(bitrate),
        "-bandwidth",
        "WB"
    ]

    call_args += options

    call_args += [
        input_pcm_path,
        bitstream_path
    ]

    try:
        if verbose:
            print(f"running {call_args}...")
            proc = subprocess.run(call_args)
        else:
            proc = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # binary missing or not executable
        return 1

    return proc.returncode
|
||||
|
||||
|
||||
def run_opus_decoder(opus_demo_path, bitstream_path, output_pcm_path, fs, num_channels, options=None, verbose=False):
    """Run opus_demo in decode mode.

    Args:
        opus_demo_path: path to the opus_demo binary.
        bitstream_path: input bitstream file.
        output_pcm_path: output raw 16-bit pcm file.
        fs: sampling rate in Hz.
        num_channels: number of channels.
        options: extra command-line options inserted before the file arguments.
        verbose: if True, print the command line and let process output through.

    Returns:
        0 on success, a non-zero value on failure.

    Fixes vs. original: mutable default argument `options=[]` replaced with
    None; bare `except:` narrowed to OSError; the decoder's exit status is
    propagated instead of always returning 0 once the process spawned.
    """
    options = [] if options is None else list(options)

    call_args = [
        opus_demo_path,
        "-d",
        str(fs),
        str(num_channels)
    ]

    call_args += options

    call_args += [
        bitstream_path,
        output_pcm_path
    ]

    try:
        if verbose:
            print(f"running {call_args}...")
            proc = subprocess.run(call_args)
        else:
            proc = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # binary missing or not executable
        return 1

    return proc.returncode
|
||||
|
||||
def compute_moc_score(reference_pcm, test_pcm, delay=91):
    """Compare two raw 16-bit pcm files with the MOC metric.

    Both files are read as int16 and scaled to [-1, 1); the first `delay`
    samples of the test signal are dropped before comparison (presumably to
    compensate codec delay — inferred from usage, not documented here).
    """
    scale = 2 ** 15
    reference = np.fromfile(reference_pcm, dtype=np.int16).astype(np.float32) / scale
    degraded = np.fromfile(test_pcm, dtype=np.int16).astype(np.float32) / scale

    return moc(reference, degraded[delay:])
|
||||
|
||||
def sox(*call_args):
    """Invoke sox with the given command-line arguments.

    Returns:
        The sox exit status, or 1 if the binary could not be run.

    Fixes vs. original: bare `except:` narrowed to OSError, and sox's exit
    status is propagated — the original returned 0 whenever the process
    could be spawned, even if sox itself failed.
    """
    try:
        proc = subprocess.run(["sox"] + list(call_args))
    except OSError:
        return 1

    return proc.returncode
|
||||
|
||||
def process_clip_factory(ref_opus_demo, test_opus_demo, enc_options, test_options):
    """Build a closure that processes one clip and returns its MOC scores.

    The returned function converts the clip to raw pcm, encodes it with the
    reference binary, decodes the bitstream with both the reference and the
    test binary, and returns (reference MOC, test MOC).
    """
    def process_clip(clip_path, processdir, bitrate):
        # derive working-file paths from the clip name
        stem = os.path.splitext(os.path.split(clip_path)[1])[0]
        pcm_path = os.path.join(processdir, stem + ".raw")
        bitstream_path = os.path.join(processdir, stem + ".bin")
        ref_path = os.path.join(processdir, stem + "_ref.raw")
        test_path = os.path.join(processdir, stem + "_test.raw")

        # convert clip to raw pcm
        sox(clip_path, pcm_path)

        # encode once with the reference binary
        run_opus_encoder(ref_opus_demo, pcm_path, bitstream_path, "voip", 16000, 1, bitrate, enc_options)

        # decode with both binaries
        run_opus_decoder(ref_opus_demo, bitstream_path, ref_path, 16000, 1)
        run_opus_decoder(test_opus_demo, bitstream_path, test_path, 16000, 1, options=test_options)

        return compute_moc_score(pcm_path, ref_path), compute_moc_score(pcm_path, test_path)

    return process_clip
|
||||
|
||||
def main(inputdir, outputdir, bitrate, reference_opus_demo, test_opus_demo, enc_option_string, test_option_string, verbose):
    """Run the MOC regression test over every clip list in inputdir.

    For each language in clips.yml all clips are processed at `bitrate`;
    a language passes unless the worst relative degradation is below -0.1
    or the mean relative degradation is below -0.025. Per-language score
    tables are saved to the output folder.
    """
    # load clips list
    with open(os.path.join(inputdir, 'clips.yml'), "r") as f:
        clips = yaml.safe_load(f)

    # parse encoder/decoder option strings
    enc_options = enc_option_string.split()
    test_options = test_option_string.split()

    process_clip = process_clip_factory(reference_opus_demo, test_opus_demo, enc_options, test_options)

    os.makedirs(outputdir, exist_ok=True)
    processdir = os.path.join(outputdir, 'process')
    os.makedirs(processdir, exist_ok=True)

    num_passed = 0
    results = dict()
    min_rel_diff = 1000
    min_mean = 1000
    worst_clip = None
    worst_lang = None

    for lang, lang_clips in clips.items():
        if verbose > 0: print(f"processing language {lang}...")
        scores = np.zeros((len(lang_clips), 2))
        results[lang] = scores

        for row, clip in enumerate(lang_clips):
            d_ref, d_test = process_clip(os.path.join(inputdir, clip), processdir, bitrate)
            scores[row, 0] = d_ref
            scores[row, 1] = d_test

        # relative difference of test vs. reference on a sqrt-compressed scale
        alpha = 0.5
        rel_diff = ((scores[:, 0] ** alpha - scores[:, 1] ** alpha) / (scores[:, 0] ** alpha))

        # track globally worst clip and worst language mean
        min_idx = np.argmin(rel_diff).item()
        if rel_diff[min_idx] < min_rel_diff:
            min_rel_diff = rel_diff[min_idx]
            worst_clip = lang_clips[min_idx]

        if np.mean(rel_diff) < min_mean:
            min_mean = np.mean(rel_diff).item()
            worst_lang = lang

        if np.min(rel_diff) < -0.1 or np.mean(rel_diff) < -0.025:
            if verbose > 0: print(f"FAIL ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            if verbose > 1:
                for i, c in enumerate(lang_clips):
                    print(f"    {c:50s} {results[lang][i]} {rel_diff[i]}")
        else:
            if verbose > 0: print(f"PASS ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            num_passed += 1

    print(f"{num_passed}/{len(clips)} tests passed!")

    print(f"worst case occured at clip {worst_clip} with relative difference of {min_rel_diff}")
    print(f"worst mean relative difference was {min_mean} for test {worst_lang}")

    np.save(os.path.join(outputdir, 'results_' + "_".join(test_options) + f"_{bitrate}.npy"), results, allow_pickle=True)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    main(args.inputdir,
         args.outputdir,
         args.bitrate,
         args.reference_opus_demo,
         args.test_opus_demo,
         args.encoder_options,
         args.test_opus_demo_options,
         args.verbose)
|
||||
@@ -0,0 +1,205 @@
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
    """Compute the gate activations of a single-layer torch GRU for one step.

    Uses the layer-0 parameters (weight_ih_l0 / weight_hh_l0 and the two
    bias vectors), whose rows are stacked as [reset; update; new].

    Returns a dict with keys 'reset_gate', 'update_gate' and 'new_gate'.
    """
    h = gru.hidden_size

    # input and recurrent projections for all three gates at once
    x_proj = torch.matmul(gru.weight_ih_l0, input.squeeze())
    h_proj = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def gate_slice(idx):
        # rows [idx*h, (idx+1)*h) hold the parameters of gate number idx
        return slice(idx * h, (idx + 1) * h)

    r = gate_slice(0)
    reset_gate = torch.sigmoid(x_proj[r] + gru.bias_ih_l0[r] + h_proj[r] + gru.bias_hh_l0[r])

    z = gate_slice(1)
    update_gate = torch.sigmoid(x_proj[z] + gru.bias_ih_l0[z] + h_proj[z] + gru.bias_hh_l0[z])

    # the reset gate modulates only the recurrent part of the new gate
    n = gate_slice(2)
    new_gate = torch.tanh(x_proj[n] + gru.bias_ih_l0[n] + reset_gate * (h_proj[n] + gru.bias_hh_l0[n]))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data """

    global _folder
    _folder = folder

    # warn (but continue) when re-using an existing folder
    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
|
||||
|
||||
def write_data(key, data, fs):
    """ appends data to previous data written under key """

    global _state

    # torch tensors are serialized via numpy
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    entry = _state.get(key)
    if entry is None:
        # first write for this key: open the binary stream and remember the
        # metadata used for consistency checks on later writes
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }

        # side-car yaml describing the stream; the dtype is stripped of any
        # module prefix (e.g. "torch.") so read_data can hand it to numpy
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # all subsequent writes must match the first one exactly
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {entry['fs']} vs. {fs}")
        if entry['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {entry['dtype']} vs. {str(data.dtype)}")
        if entry['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {entry['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
    """ clean up """
    # NOTE(review): the folder argument is accepted for symmetry with init()
    # but is not used here.
    for entry in _state.values():
        entry['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """

    datasets = dict()

    # every .yml side-car marks one recorded stream
    stream_keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    for key in stream_keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            meta = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=meta['dtype'])

        # leading -1 restores the (unknown) number of recorded frames
        meta['data'] = raw.reshape((-1,) + meta['dim'])
        datasets[key] = meta

    return datasets
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
    """Pick a 2d (rows, cols) factorization of shape near target_ratio.

    Returns (1,) when the shape holds a single value; otherwise a
    (num_rows, num_cols) pair whose product equals the element count.
    """
    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    # start at the column count implied by the ratio, then walk down to the
    # nearest divisor of the pixel count
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
    """Classify a recorded dim tuple as a 'plot' or 'image' stream.

    Returns (kind, display_shape); scalar streams become ('plot', (1,)),
    everything else becomes ('image', <2d shape>).
    """
    # one-dimensional data may report an empty dim tuple
    if len(shape) == 0:
        shape = (1,)

    # total element count
    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return 'plot', (1, )

    # keep an existing 2d shape unless it is degenerate
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render synchronized endoscopy streams into an mp4 animation.

    `data` is the dict produced by read_data(); each stream is shown either
    as a sliding waveform plot or as an image, laid out on a roughly 4:3
    grid of subplots.
    """
    # grid layout
    num_keys = len(data.keys())
    num_rows = int((num_keys * 3/4) ** .5)
    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    # the fastest stream defines the animation timeline
    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # decide how each stream is displayed
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        kind, disp_shape = get_type_and_shape(data[key]['dim'])
        display[key] = {
            'type' : kind,
            'shape' : disp_shape,
            # index scaling of slower streams relative to the fastest one
            'down_factor' : data[key]['fs'] / fs_max,
        }

    # clamp the frame range so the waveform window always fits
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples
    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # render the frames
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but the original code
            # indexes with `index`; preserved as-is, though it looks intended
            # for down-sampled streams — confirm.
            feature_index = int(round(index * display[key]['down_factor']))
            ax = axs[i // num_cols, i % num_cols]

            if display[key]['type'] == 'plot':
                window = data[key]['data'][index - half_signal_window_length : index + half_signal_window_length]
                ims.append(ax.plot(window, marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                frame_data = data[key]['data'][index].reshape(display[key]['shape'])
                ims.append(ax.imshow(frame_data, animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,25 @@
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation
|
||||
|
||||
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
    """Save an mp4 of a spectrogram with a moving playback cursor.

    The spectrogram is drawn at 5 inches per second of audio; one frame is
    emitted every 20 ms, with the first and last frames shown cursor-free.
    """
    fig, ax = plt.subplots()
    ax.set_axis_off()
    fig.set_size_inches((duration_ms / 1000 * 5, 5))

    frame_duration = 20  # ms per animation frame
    num_frames = int(duration_ms / frame_duration + .99)  # ceiling division

    spec_height, spec_width = spec.shape

    frames = []
    for i in range(num_frames):
        # horizontal cursor position for this frame
        xpos = (i - 1) / (num_frames - 3) * (spec_width - 1)
        background = ax.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        if i in {0, num_frames - 1}:
            frames.append([background])
        else:
            cursor = ax.plot([xpos, xpos], [0, spec_height - 1], color='white', alpha=0.8)[0]
            frames.append([background, cursor])

    ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)
    ani.save(savepath, dpi=720)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
96
managed_components/78__esp-opus/dnn/torch/osce/test_model.py
Normal file
96
managed_components/78__esp-opus/dnn/torch/osce/test_model.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
from models import model_dict
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils import endoscopy
|
||||
|
||||
# Debug switch: hard-code paths instead of parsing the command line.
debug = False
if debug:
    args = type('dummy', (object,),
    {
        'input' : 'testitems/all_0_orig.se',
        'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
        'output' : 'out.wav',
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str, help='path to folder with features and signals')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')

    args = parser.parse_args()


torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint

# always write a .wav extension
output_file = args.output if args.output.endswith('.wav') else args.output + '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# resolve model class; older checkpoints may lack a name entry
if 'name' in checkpoint['setup']['model']:
    model_name = checkpoint['setup']['model']['name']
else:
    print('warning: did not find model name entry in setup, using pitchpostfilter per default')
    model_name = 'pitchpostfilter'

model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])

# load features/signal with the data settings stored in the checkpoint
setup = checkpoint['setup']
signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])

if args.debug:
    endoscopy.init()

output = model.process(signal, features, periods, numbits, debug=args.debug)

wavfile.write(output_file, 16000, output.cpu().numpy())

if args.debug:
    endoscopy.close()
|
||||
103
managed_components/78__esp-opus/dnn/torch/osce/test_vocoder.py
Normal file
103
managed_components/78__esp-opus/dnn/torch/osce/test_vocoder.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
from time import time
|
||||
|
||||
|
||||
from models import model_dict
|
||||
from utils.lpcnet_features import load_lpcnet_features
|
||||
from utils import endoscopy
|
||||
|
||||
# Debug switch: hard-code paths instead of parsing the command line.
debug = False
if debug:
    args = type('dummy', (object,),
    {
        'input' : 'testitems/all_0_orig.se',
        'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth',
        'output' : 'out.wav',
    })()
else:
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str, help='path to input features')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')

    args = parser.parse_args()


torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint

# always write a .wav extension
output_file = args.output if args.output.endswith('.wav') else args.output + '.wav'

checkpoint = torch.load(checkpoint_file, map_location="cpu")

# resolve model class; older checkpoints may lack a name entry
if 'name' in checkpoint['setup']['model']:
    model_name = checkpoint['setup']['model']['name']
else:
    print('warning: did not find model name entry in setup, using pitchpostfilter per default')
    model_name = 'pitchpostfilter'

model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
model.load_state_dict(checkpoint['state_dict'])

# load lpcnet features for vocoding
setup = checkpoint['setup']
testdata = load_lpcnet_features(input_folder)
features = testdata['features']
periods = testdata['periods']

if args.debug:
    endoscopy.init()

# time the inference pass
start = time()
output = model.process(features, periods, debug=args.debug)
elapsed = time() - start
print(f"[timing] inference took {elapsed * 1000} ms")

wavfile.write(output_file, 16000, output.cpu().numpy())

if args.debug:
    endoscopy.close()
|
||||
307
managed_components/78__esp-opus/dnn/torch/osce/train_model.py
Normal file
307
managed_components/78__esp-opus/dnn/torch/osce/train_model.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
seed=1888
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
import random
|
||||
random.seed(seed)
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
torch.manual_seed(seed)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
import numpy as np
|
||||
np.random.seed(seed)
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
import pesq
|
||||
|
||||
from data import SilkEnhancementSet
|
||||
from models import model_dict
|
||||
from engine.engine import train_one_epoch, evaluate
|
||||
|
||||
|
||||
from utils.silk_features import load_inference_data
|
||||
from utils.misc import count_parameters, count_nonzero_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder
|
||||
if os.path.exists(args.output):
|
||||
print("warning: output folder exists")
|
||||
|
||||
reply = input('continue? (y/n): ')
|
||||
while reply not in {'y', 'n'}:
|
||||
reply = input('continue? (y/n): ')
|
||||
|
||||
if reply == 'n':
|
||||
os._exit(0)
|
||||
else:
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
with open(os.path.join(args.output, 'repo.diff'), "w") as f:
|
||||
f.write(repo.git.execute(["git", "diff"]))
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
ref = None
|
||||
if args.testdata is not None:
|
||||
|
||||
testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])
|
||||
|
||||
inference_test = True
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True)
|
||||
|
||||
try:
|
||||
ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
inference_test = False
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = SilkEnhancementSet(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
if args.initial_checkpoint is not None:
|
||||
print(f"loading state dict from {args.initial_checkpoint}...")
|
||||
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(chkpt['state_dict'])
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# push model to device
|
||||
model.to(device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr)
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss: mean over the batch of 1 - NCC(y_true, y_pred)."""
    reduce_dims = list(range(1, y_true.dim()))

    inner = torch.sum(y_true * y_pred, dim=reduce_dims)
    energy_true = torch.sum(y_true ** 2, dim=reduce_dims)
    energy_pred = torch.sum(y_pred ** 2, dim=reduce_dims)

    # 1e-9 keeps the square root finite for all-zero signals
    loss = 1 - inner / torch.sqrt(energy_true * energy_pred + 1e-9)

    return torch.mean(loss)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain MSE normalized by the RMS of the prediction, averaged over the batch."""
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    # 1e-6 keeps the denominator finite for silent predictions
    scale = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5 + 1e-6

    return torch.mean(mse / scale)
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally normalized by mean |y_pred| raised to *pow*.

    With the default pow=0 the normalizer is 1 and this is plain MAE.
    """
    reduce_dims = list(range(1, y_true.dim()))

    mae = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    # 1e-9 avoids 0 ** negative blow-ups for silent predictions
    norm = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return torch.mean(mae / norm)
|
||||
|
||||
def criterion(x, y):
    """Total training loss: weighted time-domain L1/L2, multi-resolution STFT,
    log-mel and cross-correlation terms, normalized by the sum of all weights.

    The spectral weights (w_sc, w_lm, w_wsc, w_slm, w_sxcorr) are already
    folded into the stftloss module; the remaining weights are applied here.
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
|
||||
if ref is not None:
|
||||
noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
|
||||
initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
|
||||
print(f"initial MOS (PESQ): {initial_mos}")
|
||||
|
||||
best_loss = 1e9
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['loss'] = new_loss
|
||||
|
||||
if run_validation:
|
||||
print("running validation...")
|
||||
validation_loss = evaluate(model, criterion, validation_dataloader, device)
|
||||
checkpoint['validation_loss'] = validation_loss
|
||||
|
||||
if validation_loss < best_loss:
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
|
||||
best_loss = validation_loss
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(testsignal, features, periods, numbits).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
|
||||
|
||||
print('Done')
|
||||
287
managed_components/78__esp-opus/dnn/torch/osce/train_vocoder.py
Normal file
287
managed_components/78__esp-opus/dnn/torch/osce/train_vocoder.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
import git
|
||||
has_git = True
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from scipy.io import wavfile
|
||||
|
||||
import pesq
|
||||
|
||||
from data import LPCNetVocodingDataset
|
||||
from models import model_dict
|
||||
from engine.vocoder_engine import train_one_epoch, evaluate
|
||||
|
||||
|
||||
from utils.lpcnet_features import load_lpcnet_features
|
||||
from utils.misc import count_parameters
|
||||
|
||||
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('setup', type=str, help='setup yaml file')
|
||||
parser.add_argument('output', type=str, help='output path')
|
||||
parser.add_argument('--device', type=str, help='compute device', default=None)
|
||||
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
|
||||
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
|
||||
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
|
||||
with open(args.setup, 'r') as f:
|
||||
setup = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
checkpoint_prefix = 'checkpoint'
|
||||
output_prefix = 'output'
|
||||
setup_name = 'setup.yml'
|
||||
output_file='out.txt'
|
||||
|
||||
|
||||
# check model
|
||||
if not 'name' in setup['model']:
|
||||
print(f'warning: did not find model entry in setup, using default PitchPostFilter')
|
||||
model_name = 'pitchpostfilter'
|
||||
else:
|
||||
model_name = setup['model']['name']
|
||||
|
||||
# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    # keep asking until we get a definitive answer
    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # bug fix: os._exit() requires an exit status argument and would
        # raise a TypeError here; exit with status 0 as train_model.py does
        os._exit(0)
else:
    os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# add repo info to setup
|
||||
if has_git:
|
||||
working_dir = os.path.split(__file__)[0]
|
||||
try:
|
||||
repo = git.Repo(working_dir, search_parent_directories=True)
|
||||
setup['repo'] = dict()
|
||||
hash = repo.head.object.hexsha
|
||||
urls = list(repo.remote().urls)
|
||||
is_dirty = repo.is_dirty()
|
||||
|
||||
if is_dirty:
|
||||
print("warning: repo is dirty")
|
||||
|
||||
setup['repo']['hash'] = hash
|
||||
setup['repo']['urls'] = urls
|
||||
setup['repo']['dirty'] = is_dirty
|
||||
except:
|
||||
has_git = False
|
||||
|
||||
# dump setup
|
||||
with open(os.path.join(args.output, setup_name), 'w') as f:
|
||||
yaml.dump(setup, f)
|
||||
|
||||
ref = None
|
||||
# prepare inference test if wanted
|
||||
inference_test = False
|
||||
if type(args.test_features) != type(None):
|
||||
test_features = load_lpcnet_features(args.test_features)
|
||||
features = test_features['features']
|
||||
periods = test_features['periods']
|
||||
inference_folder = os.path.join(args.output, 'inference_test')
|
||||
os.makedirs(inference_folder, exist_ok=True)
|
||||
inference_test = True
|
||||
|
||||
|
||||
# training parameters
|
||||
batch_size = setup['training']['batch_size']
|
||||
epochs = setup['training']['epochs']
|
||||
lr = setup['training']['lr']
|
||||
lr_decay_factor = setup['training']['lr_decay_factor']
|
||||
|
||||
# load training dataset
|
||||
data_config = setup['data']
|
||||
data = LPCNetVocodingDataset(setup['dataset'], **data_config)
|
||||
|
||||
# load validation dataset if given
|
||||
if 'validation_dataset' in setup:
|
||||
validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
|
||||
|
||||
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
|
||||
|
||||
run_validation = True
|
||||
else:
|
||||
run_validation = False
|
||||
|
||||
# create model
|
||||
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])
|
||||
|
||||
if args.initial_checkpoint is not None:
|
||||
print(f"loading state dict from {args.initial_checkpoint}...")
|
||||
chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
|
||||
model.load_state_dict(chkpt['state_dict'])
|
||||
|
||||
# set compute device
|
||||
if type(args.device) == type(None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
# push model to device
|
||||
model.to(device)
|
||||
|
||||
# dataloader
|
||||
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)
|
||||
|
||||
# optimizer is introduced to trainable parameters
|
||||
parameters = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.Adam(parameters, lr=lr)
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
# loss
|
||||
w_l1 = setup['training']['loss']['w_l1']
|
||||
w_lm = setup['training']['loss']['w_lm']
|
||||
w_slm = setup['training']['loss']['w_slm']
|
||||
w_sc = setup['training']['loss']['w_sc']
|
||||
w_logmel = setup['training']['loss']['w_logmel']
|
||||
w_wsc = setup['training']['loss']['w_wsc']
|
||||
w_xcorr = setup['training']['loss']['w_xcorr']
|
||||
w_sxcorr = setup['training']['loss']['w_sxcorr']
|
||||
w_l2 = setup['training']['loss']['w_l2']
|
||||
|
||||
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
|
||||
|
||||
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
|
||||
logmelloss = MRLogMelLoss().to(device)
|
||||
|
||||
def xcorr_loss(y_true, y_pred):
    """Normalized cross-correlation loss: mean over the batch of 1 - NCC(y_true, y_pred)."""
    reduce_dims = list(range(1, y_true.dim()))

    inner = torch.sum(y_true * y_pred, dim=reduce_dims)
    energy_true = torch.sum(y_true ** 2, dim=reduce_dims)
    energy_pred = torch.sum(y_pred ** 2, dim=reduce_dims)

    # 1e-9 keeps the square root finite for all-zero signals
    loss = 1 - inner / torch.sqrt(energy_true * energy_pred + 1e-9)

    return torch.mean(loss)
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
    """Time-domain MSE normalized by the RMS of the prediction, averaged over the batch."""
    reduce_dims = list(range(1, y_true.dim()))

    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    # 1e-6 keeps the denominator finite for silent predictions
    scale = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5 + 1e-6

    return torch.mean(mse / scale)
|
||||
|
||||
def td_l1(y_true, y_pred, pow=0):
    """Mean absolute error, optionally normalized by mean |y_pred| raised to *pow*.

    With the default pow=0 the normalizer is 1 and this is plain MAE.
    """
    reduce_dims = list(range(1, y_true.dim()))

    mae = torch.mean(torch.abs(y_true - y_pred), dim=reduce_dims)
    # 1e-9 avoids 0 ** negative blow-ups for silent predictions
    norm = (torch.mean(torch.abs(y_pred), dim=reduce_dims) + 1e-9) ** pow

    return torch.mean(mae / norm)
|
||||
|
||||
def criterion(x, y):
    """Total training loss: weighted time-domain L1/L2, multi-resolution STFT,
    log-mel and cross-correlation terms, normalized by the sum of all weights.

    The spectral weights (w_sc, w_lm, w_wsc, w_slm, w_sxcorr) are already
    folded into the stftloss module; the remaining weights are applied here.
    """
    total = w_l1 * td_l1(x, y, pow=1)
    total = total + stftloss(x, y)
    total = total + w_logmel * logmelloss(x, y)
    total = total + w_xcorr * xcorr_loss(x, y)
    total = total + w_l2 * td_l2_norm(x, y)

    return total / w_sum
|
||||
|
||||
|
||||
|
||||
# model checkpoint
|
||||
checkpoint = {
|
||||
'setup' : setup,
|
||||
'state_dict' : model.state_dict(),
|
||||
'loss' : -1
|
||||
}
|
||||
|
||||
|
||||
if not args.no_redirect:
|
||||
print(f"re-directing output to {os.path.join(args.output, output_file)}")
|
||||
sys.stdout = open(os.path.join(args.output, output_file), "w")
|
||||
|
||||
print("summary:")
|
||||
|
||||
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
|
||||
if hasattr(model, 'flop_count'):
|
||||
print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")
|
||||
|
||||
if ref is not None:
|
||||
pass
|
||||
|
||||
best_loss = 1e9
|
||||
|
||||
for ep in range(1, epochs + 1):
|
||||
print(f"training epoch {ep}...")
|
||||
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
|
||||
|
||||
|
||||
# save checkpoint
|
||||
checkpoint['state_dict'] = model.state_dict()
|
||||
checkpoint['loss'] = new_loss
|
||||
|
||||
if run_validation:
|
||||
print("running validation...")
|
||||
validation_loss = evaluate(model, criterion, validation_dataloader, device)
|
||||
checkpoint['validation_loss'] = validation_loss
|
||||
|
||||
if validation_loss < best_loss:
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
|
||||
best_loss = validation_loss
|
||||
|
||||
if inference_test:
|
||||
print("running inference test...")
|
||||
out = model.process(features, periods).cpu().numpy()
|
||||
wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
|
||||
if ref is not None:
|
||||
mos = pesq.pesq(16000, ref, out, mode='wb')
|
||||
print(f"MOS (PESQ): {mos}")
|
||||
|
||||
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
|
||||
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
print('Done')
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jean-Marc Valin */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# x is (batch, nb_in_channels, nb_frames*frame_size)
# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """Apply per-frame adaptive convolution kernels to x in the frequency
    domain, cross-fading between consecutive frames' filter outputs.

    Args:
        x: input signal, (batch, nb_in_channels, nb_frames * frame_size).
        kernels: per-frame filters, (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs).
        half_window: rising half of the cross-fade window; its length defines
            the overlap size between adjacent frames.
        fft_size: FFT length for the circular convolution. Must satisfy
            fft_size >= 2 * frame_size + overlap_size, otherwise the
            zero-padding below gets a negative length.

    Returns:
        Filtered signal, (batch, nb_out_channels, nb_frames * frame_size).
    """
    device=x.device
    overlap_size=half_window.size(-1)
    nb_frames=kernels.size(3)
    nb_batches=kernels.size(0)
    nb_out_channels=kernels.size(1)
    nb_in_channels=kernels.size(2)
    kernel_size = kernels.size(-1)
    # split the time axis into frames: (batch, 1, in_ch, nb_frames, frame_size)
    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)
    # build window: [zeros, rising window, ones, falling window, zeros]
    window = torch.cat(
        [
            torch.zeros(frame_size, device=device),
            half_window,
            torch.ones(frame_size - overlap_size, device=device),
            1 - half_window,
            torch.zeros(fft_size - 2 * frame_size - overlap_size,device=device)
        ])
    # previous frame (zero-padded at the start) and the first overlap_size
    # samples of the next frame (zero-padded at the end)
    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
    x_next = torch.cat([x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])], dim=-2)
    # each frame's FFT buffer holds [prev frame | frame | next-frame overlap | zero pad]
    x_padded = torch.cat([x_prev, x, x_next, torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, fft_size - 2 * frame_size - overlap_size, device=device)], -1)
    # kernels are time-reversed so the spectral product implements correlation-style filtering
    k_padded = torch.cat([torch.flip(kernels, [-1]), torch.zeros(nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size-kernel_size, device=device)], dim=-1)

    # compute convolution
    X = torch.fft.rfft(x_padded, dim=-1)
    K = torch.fft.rfft(k_padded, dim=-1)

    out = torch.fft.irfft(X * K, dim=-1)
    # combine in channels
    out = torch.sum(out, dim=2)
    # apply the cross-fading
    out = window.reshape(1, 1, 1, -1)*out
    # overlap-add: each frame's windowed output plus the previous frame's
    # fade-out region (shifted by one frame along the frame axis)
    crossfaded = out[:,:,:,frame_size:2*frame_size] + torch.cat([torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device), out[:, :, :-1, 2*frame_size:3*frame_size]], dim=-2)

    # flatten frames back into a contiguous time axis
    return crossfaded.reshape(nb_batches, nb_out_channels, -1)
|
||||
@@ -0,0 +1,8 @@
|
||||
|
||||
|
||||
def _conv1d_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0]
|
||||
|
||||
|
||||
def _dense_flop_count(layer, rate):
|
||||
return 2 * ((layer.in_features + 1) * layer.out_features * rate )
|
||||
@@ -0,0 +1,205 @@
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
    """Compute the reset, update and new gate activations of a torch GRU for
    a single time step, for inspection purposes.

    NOTE(review): only the layer-0 weights (weight_ih_l0 / weight_hh_l0) and
    biases are used, so this assumes a single-layer, unidirectional GRU with
    bias=True — confirm before using on other configurations.

    Args:
        gru: torch.nn.GRU module.
        input: input vector for this step; squeezed to 1-D before the matmul.
        state: previous hidden state; squeezed to 1-D before the matmul.

    Returns:
        dict with 'reset_gate', 'update_gate' and 'new_gate' tensors, each of
        length gru.hidden_size.
    """
    hidden_size = gru.hidden_size

    # input and recurrent contributions for all three gates stacked as
    # [reset | update | new] along dim 0 (PyTorch's GRU weight layout)
    direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
    recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())

    # reset gate
    start, stop = 0 * hidden_size, 1 * hidden_size
    reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])

    # update gate
    start, stop = 1 * hidden_size, 2 * hidden_size
    update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])

    # new gate: the reset gate scales the recurrent contribution (incl. its bias)
    start, stop = 2 * hidden_size, 3 * hidden_size
    new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
    """Set up (and remember) the output folder for endoscopy data.

    Creates the folder if it does not exist; an existing folder is reused
    with a warning, since stale files may mix with new ones.
    """
    global _folder
    _folder = folder

    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
|
||||
|
||||
def write_data(key, data, fs):
    """ appends data to previous data written under key """
    # On first use of a key, lazily opens <_folder>/<key>.bin for binary
    # appending and writes a <key>.yml sidecar describing fs, per-record
    # shape and dtype. Later calls must match that metadata exactly.

    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    if not key in _state:
        # first record for this key: open the data file and cache metadata
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }

        # sidecar metadata; split('.')[-1] strips a module prefix from the
        # dtype name (a no-op for plain numpy dtype strings)
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # enforce consistency with the metadata recorded at first write
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    # append the raw bytes of this record to the data file
    _state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
    """Close all open endoscopy data files and reset the module state.

    The *folder* argument is unused and kept only for interface
    compatibility with existing callers.
    """
    global _state

    # close every open data file
    for entry in _state.values():
        entry['fid'].close()

    # bug fix: clear the registry so a subsequent init()/write_data() starts
    # fresh instead of writing to the closed file handles cached here
    _state = dict()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays """
    # Scans *folder* for <key>.yml sidecars written by write_data, loads each
    # metadata dict, reads the matching <key>.bin and reshapes the raw bytes
    # into (num_records, *dim). Returns {key: metadata-with-'data'-array}.
    # NOTE(review): write_data dumps 'dim' as a Python tuple; yaml.FullLoader
    # does not construct python/tuple tags — confirm the yml round-trip
    # actually yields a tuple/list usable in the reshape below.

    # strip the '.yml' suffix to recover the keys
    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # leading -1 recovers the number of appended records
        value['data'] = data.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
    """Pick a 2d (rows, cols) factorization of the element count of *shape*
    whose rows/cols ratio approximates target_ratio; returns (1,) for a
    single element."""

    # total element count
    if len(shape) > 1:
        pixel_count = 1
        for extent in shape:
            pixel_count *= extent
    else:
        pixel_count = shape[0]

    if pixel_count == 1:
        return (1,)

    # start at the ideal column count and walk down to the nearest divisor
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
    """Classify data of dimension *shape* for display: ('plot', (1,)) for a
    single value, otherwise ('image', display_shape)."""

    # scalar data may arrive with an empty shape
    if len(shape) == 0:
        shape = (1,)

    # total element count
    pixel_count = shape[0]
    for extent in shape[1:]:
        pixel_count *= extent

    if pixel_count == 1:
        return 'plot', (1, )

    # keep an explicit 2d shape as-is (unless it is the degenerate square case)
    if len(shape) == 2 and ((shape[0] != pixel_count) or (shape[1] != pixel_count)):
        return 'image', shape

    # otherwise derive a near-square 2d layout
    return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render an endoscopy data dict (as returned by read_data) to an mp4:
    one subplot per key, scalar keys drawn as sliding signal plots and
    multi-dimensional keys as images.

    Args:
        data: dict key -> {'data', 'fs', 'dim', ...}.
        filename: output path; '.mp4' is appended if missing.
        start_index, stop_index: frame range at the highest sample rate
            (negative stop_index counts from the end).
        interval: delay between animation frames in ms.
        half_signal_window_length: half-width of the window shown for signals.

    NOTE(review): num_rows is 0 for num_keys < 2, which would make the
    subsequent division fail; axs[r, c] indexing also assumes a 2-D axes
    grid (num_rows > 1 and num_cols > 1) — confirm for small key counts.
    """

    # determine plot setup
    num_keys = len(data.keys())

    # aim for a roughly 4:3 grid of subplots
    num_rows = int((num_keys * 3/4) ** .5)

    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    # reference rate for aligning keys recorded at different sample rates
    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()

        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the frame range so the signal window never runs off either end
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but never used below —
            # image frames index by `index` directly; possibly intended for
            # keys sampled below fs_max
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)
|
||||
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FIR(nn.Module):
    """Fixed (non-trainable) least-squares FIR filter applied per channel.

    Coefficients are designed with scipy.signal.firls at construction time.
    Filtering is a grouped 1d convolution without padding, so the output is
    numtaps - 1 samples shorter than the input.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        """
        Args:
            numtaps: number of filter taps; must be odd (incremented with a
                warning if even, as firls requires an odd length).
            bands: band edge frequencies passed to scipy.signal.firls.
            desired: desired gains at the band edges.
            fs: sampling frequency convention for the band edges (default 2,
                i.e. edges expressed as fractions of the Nyquist frequency).
        """
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # bug fix: register the coefficients as a (non-persistent) buffer so
        # they follow the module through .to(device)/.cuda(); the original
        # plain-tensor attribute stayed on the CPU and broke GPU inference.
        # persistent=False keeps the state_dict unchanged for old checkpoints.
        self.register_buffer('weight', torch.from_numpy(a.astype(np.float32)), persistent=False)

    def forward(self, x):
        """Filter x of shape (batch, channels, time) channel-wise.

        Returns a tensor of shape (batch, channels, time - numtaps + 1).
        """
        num_channels = x.size(1)

        # replicate the single filter once per channel and convolve groupwise
        weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y
|
||||
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveComb1d(nn.Module):
    """Frame-wise adaptive comb filter with limited gain.

    A single-channel comb filter whose kernel, comb gain and global gain are
    predicted per frame from conditioning features. The comb tap is placed at
    a frame-wise pitch lag and consecutive frames are cross-faded via
    overlap-add.
    """

    # instance counter used to auto-generate module names
    COUNTER = 1

    def __init__(self,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 max_lag=256,
                 name=None,
                 gain_limit_db=10,
                 global_gain_limits_db=None,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """

        Parameters:
        -----------

        kernel_size : int
            length of the per-frame comb kernel

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int, optional
            frame size, defaults to 160

        overlap_size : int, optional
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40

        padding : List[int, int], optional
            left and right padding. Defaults to [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        max_lag : int, optional
            maximal pitch lag, defaults to 256

        name : str or None, optional
            name attribute for the module. If None the name is auto generated as
            limited_adaptive_comb1d_COUNT, where COUNT is an instance counter

        gain_limit_db : float, optional
            upper limit for the comb gain in dB, defaults to 10

        global_gain_limits_db : List[float, float] or None, optional
            [min, max] global gain limits in dB; None (the default) means [-6, 6].
            A None sentinel is used to avoid a mutable default argument.

        norm_p : int, optional
            p for the kernel norm used for normalization, defaults to 2

        softquant : bool, optional
            if True, wraps the kernel predictor with soft_quant

        apply_weight_norm : bool, optional
            if True, applies weight norm to the predictor layers
        """

        super().__init__()

        # sentinel instead of a mutable default argument
        if global_gain_limits_db is None:
            global_gain_limits_db = [-6, 6]

        # the comb filter operates on a single channel
        self.in_channels = 1
        self.out_channels = 1
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.max_lag = max_lag
        self.limit_db = gain_limit_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
            LimitedAdaptiveComb1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))

        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # comb filter gain; 0.11512925464970229 == ln(10) / 20 converts dB to
        # natural-log units
        self.filter_gain = norm(nn.Linear(feature_dim, 1))
        self.log_gain_limit = gain_limit_db * 0.11512925464970229
        with torch.no_grad():
            # bias init keeps the initial comb gain small
            self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)

        # global gain is mapped into [log_min, log_max] via a*tanh(.) + b
        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
        log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if padding is None:
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # half-cosine cross-fade window (fixed, not trained)
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)

    def forward(self, x, features, lags, debug=False):
        """ adaptive comb filtering with frame-wise kernels, gains and lags

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        lags : torch.LongTensor
            frame-wise pitch lags for comb-filtering

        debug : bool, optional
            if True and batch_size == 1, dumps gains/kernels/lags via endoscopy

        Returns:
        --------
        torch.tensor of shape (batch_size, out_channels, num_samples)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # win1 fades the current frame in, win2 fades the previous frame's tail out
        win1 = torch.flip(self.overlap_win, [0])
        win2 = self.overlap_win

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        # per-frame kernels, normalized to (approximately) unit p-norm
        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))

        # comb gain limited to (0, exp(log_gain_limit)]
        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
        # global gain limited to [log_min, log_max] via tanh
        global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)

        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_lags"
            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_global_conv_gains"
            write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        # frame-wise convolution with overlap-add
        output_frames = []
        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
        # kernel padding, then room for max_lag of lookback and overlap lookahead
        x = F.pad(x, self.padding)
        x = F.pad(x, [self.max_lag, self.overlap_size])

        # gather index template covering one frame plus kernel and overlap tails
        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
        idx = torch.repeat_interleave(idx, batch_size, 0)
        idx = torch.repeat_interleave(idx, self.in_channels, 1)

        for i in range(num_frames):
            # signal segment for frame i, shifted back by the frame's pitch lag
            cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))

            # batched per-frame convolution via grouping over the batch
            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)

            # comb output: (filtered delayed signal * comb gain + dry signal),
            # the sum scaled by the global gain
            offset = self.max_lag + self.padding[0]
            new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])

            # overlapping part: cross-fade with the previous frame's tail
            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)

            # non-overlapping part
            output_frames.append(new_chunk[:, :, overlap_size : frame_size])

            # mem for next frame
            overlap_mem = new_chunk[:, :, frame_size :]

        # concatenate chunks
        output = torch.cat(output_frames, dim=-1)

        return output

    def flop_count(self, rate):
        """Approximate FLOPs per second at sampling rate `rate`."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # a0 computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += overlap * frame_rate * 3 * self.out_channels

        return count
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.endoscopy import write_data
|
||||
|
||||
from utils.ada_conv import adaconv_kernel
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class LimitedAdaptiveConv1d(nn.Module):
    """Frame-wise adaptive 1d convolution with limited gain and limited
    kernel shaping.

    Kernels and gains are predicted per frame from conditioning features;
    kernels are normalized and blended with an identity (delta) kernel, and
    frames are cross-faded via overlap-add inside adaconv_kernel.
    """

    # instance counter used to auto-generate module names
    COUNTER = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 name=None,
                 gain_limits_db=None,
                 shape_gain_db=0,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """

        Parameters:
        -----------

        in_channels : int
            number of input channels

        out_channels : int
            number of output channels

        kernel_size : int
            length of the per-frame kernels

        feature_dim : int
            dimension of features from which kernels, biases and gains are computed

        frame_size : int, optional
            frame size, defaults to 160

        overlap_size : int, optional
            overlap size for filter cross-fade; cross-fade is done on the first
            overlap_size samples of every frame, defaults to 40

        padding : List[int, int] or None, optional
            left and right padding; defaults to
            [kernel_size // 2, kernel_size - 1 - kernel_size // 2]

        name : str or None, optional
            module name; auto-generated as limited_adaptive_conv1d_COUNT if None

        gain_limits_db : List[float, float] or None, optional
            [min, max] gain limits in dB; None (the default) means [-6, 6].
            A None sentinel is used to avoid a mutable default argument.

        shape_gain_db : float, optional
            limits deviation from an identity kernel; 0 dB (default) allows the
            full predicted kernel

        norm_p : int, optional
            p for the kernel norm used for normalization, defaults to 2

        softquant : bool, optional
            if True, wraps the kernel predictor with soft_quant

        apply_weight_norm : bool, optional
            if True, applies weight norm to the predictor layers
        """

        super().__init__()

        # sentinel instead of a mutable default argument
        if gain_limits_db is None:
            gain_limits_db = [-6, 6]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.feature_dim = feature_dim
        self.kernel_size = kernel_size
        self.frame_size = frame_size
        self.overlap_size = overlap_size
        self.gain_limits_db = gain_limits_db
        self.shape_gain_db = shape_gain_db
        self.norm_p = norm_p

        if name is None:
            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
            LimitedAdaptiveConv1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # blend factor between predicted kernel and identity kernel
        self.shape_gain = min(1, 10**(shape_gain_db / 20))

        # gain is mapped into [log_min, log_max] via a*tanh(.) + b;
        # 0.11512925464970229 == ln(10) / 20 converts dB to natural-log units
        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2

        if padding is None:
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # half-cosine cross-fade window (fixed, not trained)
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)


    def flop_count(self, rate):
        """Approximate FLOPs per second at sampling rate `rate`."""
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # gain computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += 3 * overlap * frame_rate * self.out_channels

        return count

    def forward(self, x, features, debug=False):
        """ adaptive 1d convolution

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        debug : bool, optional
            if True and batch_size == 1, dumps gains/kernels via endoscopy

        Returns:
        --------
        torch.tensor of shape (batch_size, out_channels, num_samples)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size
        overlap_size = self.overlap_size
        kernel_size = self.kernel_size
        # cross-fade window for adaconv_kernel (rising edge)
        win1 = torch.flip(self.overlap_win, [0])

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in AdaptiveConv1d.forward')

        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))

        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))

        # limit shape: blend the predicted kernel with an identity (delta) kernel
        id_kernels = torch.zeros_like(conv_kernels)
        id_kernels[..., self.padding[1]] = 1

        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels

        # calculate gains
        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)

        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)

        # FFT-based frame-wise convolution with overlap-add cross-fade
        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)

        return output
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
|
||||
class NoiseShaper(nn.Module):
    """Generates white noise with a frame-wise temporal envelope predicted
    from conditioning features."""

    def __init__(self,
                 feature_dim,
                 frame_size=160
                 ):
        """

        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size

        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size

        # two causal conv layers mapping features to per-sample noise gains
        self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)


    def flop_count(self, rate):
        """Approximate FLOPs per second at sampling rate `rate`."""

        frames_per_second = rate / self.frame_size

        conv_flops = sum(_conv1d_flop_count(layer, frames_per_second)
                         for layer in (self.feature_alpha1, self.feature_alpha2))

        return conv_flops + 11 * frames_per_second * self.frame_size


    def forward(self, features):
        """ creates temporally shaped noise


        Parameters:
        -----------
        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        Returns:
        --------
        torch.tensor of shape (batch_size, 1, num_frames * frame_size)
        """

        batch_size, num_frames, _ = features.shape
        num_samples = num_frames * self.frame_size

        # causal feature path: left-pad by one frame before each conv,
        # exponential to obtain strictly positive gains
        hidden = F.pad(features.permute(0, 2, 1), [1, 0])
        hidden = F.leaky_relu(self.feature_alpha1(hidden), 0.2)
        gains = torch.exp(self.feature_alpha2(F.pad(hidden, [1, 0])))
        gains = gains.permute(0, 2, 1)

        # scale unit-variance white noise by the predicted envelope
        noise = torch.randn((batch_size, num_frames, self.frame_size),
                            dtype=features.dtype, device=features.device)
        shaped = gains * noise

        return shaped.reshape(batch_size, 1, num_samples)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PitchAutoCorrelator(nn.Module):
    """Computes frame-wise normalized autocorrelation between a signal and
    its pitch-lagged version, for lag offsets in [-radius, radius]."""

    def __init__(self,
                 frame_size=80,
                 pitch_min=32,
                 pitch_max=300,
                 radius=2):
        """
        Parameters:
        -----------
        frame_size : int, optional
            samples per frame, defaults to 80
        pitch_min : int, optional
            minimal expected pitch lag (stored but not enforced in forward)
        pitch_max : int, optional
            maximal pitch lag; also used as left zero-padding length
        radius : int, optional
            half-width of the lag search window around each period
        """

        super().__init__()

        self.frame_size = frame_size
        self.pitch_min = pitch_min
        self.pitch_max = pitch_max
        self.radius = radius


    def forward(self, x, periods):
        """Return normalized autocorrelation of shape
        (batch_size, 1, num_frames, 2 * radius + 1).

        Parameters:
        -----------
        x : torch.tensor
            signal of shape (batch_size, channels, num_samples); the lagged
            reference is taken from channel 0
        periods : torch.LongTensor
            frame-wise pitch periods of shape (batch_size, num_frames);
            values are expected in [pitch_min, pitch_max] — not checked here
        """

        num_frames = periods.size(1)
        batch_size = periods.size(0)
        num_samples = self.frame_size * num_frames
        channels = x.size(1)

        assert num_samples == x.size(-1)

        # renamed from `range` to avoid shadowing the builtin
        offsets = torch.arange(-self.radius, self.radius + 1, device=x.device)
        idx = torch.arange(self.frame_size * num_frames, device=x.device)
        # per-sample lag lookup positions in the padded signal
        p_up = torch.repeat_interleave(periods, self.frame_size, 1)
        lookup = idx + self.pitch_max - p_up
        lookup = lookup.unsqueeze(-1) + offsets
        lookup = lookup.unsqueeze(1)

        # padding: pitch_max zeros on the left so all lookups are in range
        x_pad = F.pad(x, [self.pitch_max, 0])
        x_ext = torch.repeat_interleave(x_pad.unsqueeze(-1), 2 * self.radius + 1, -1)

        # framing
        x_select = torch.gather(x_ext, 2, lookup)
        x_frames = x_pad[..., self.pitch_max : ].reshape(batch_size, channels, num_frames, self.frame_size, 1)
        lag_frames = x_select.reshape(batch_size, 1, num_frames, self.frame_size, -1)

        # calculate auto-correlation, normalized by the frame energies
        dotp = torch.sum(x_frames * lag_frames, dim=-2)
        frame_nrg = torch.sum(x_frames * x_frames, dim=-2)
        lag_frame_nrg = torch.sum(lag_frames * lag_frames, dim=-2)

        # 1e-9 guards against division by zero on silent frames
        acorr = dotp / torch.sqrt(frame_nrg * lag_frame_nrg + 1e-9)

        return acorr
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fractional-delay interpolation filter bank: 12 phases of an 8-tap FIR in
# Q15 fixed point (hence the division by 2**15).  Presumably taken from the
# fixed-point SILK resampler tables — TODO confirm phase convention against
# the C implementation.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
    ],
    dtype=np.float32
) / 2**15


# Q16 allpass coefficients for the even/odd polyphase branches of the
# high-quality 2x upsampler; the third coefficient is stored with a -2**16
# offset (39083 - 65536, 55542 - 65536), so it comes out negative.
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
|
||||
|
||||
|
||||
def get_impz(coeffs, n):
    """Return the length-n impulse response of the 3-section allpass cascade
    parameterized by `coeffs` (the third section uses 1 + coeffs[2])."""
    state = [0, 0, 0]
    out = np.zeros(n)
    sample = 1  # unit impulse

    for k in range(n):
        # section 0
        diff = sample - state[0]
        upd = diff * coeffs[0]
        stage1 = state[0] + upd
        state[0] = sample + upd

        # section 1
        diff = stage1 - state[1]
        upd = diff * coeffs[1]
        stage2 = state[1] + upd
        state[1] = stage1 + upd

        # section 2 (coefficient stored with an offset of -1)
        diff = stage2 - state[2]
        upd = diff * (1 + coeffs[2])
        stage3 = state[2] + upd
        state[2] = stage2 + upd

        out[k] = stage3
        sample = 0  # impulse is zero after the first sample

    return out
|
||||
|
||||
|
||||
|
||||
class SilkUpsampler(nn.Module):
    """Torch reimplementation of the SILK upsampler from 16 kHz to 24 or
    48 kHz: an exact-ratio 2x stage followed by 3/2 fractional interpolation,
    optionally decimated by 2 for the 24 kHz target."""

    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}

    def __init__(self,
                 fs_in=16000,
                 fs_out=48000):

        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')

        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')

        # hq 2x upsampler approximated as a 128-tap FIR; impulse responses are
        # time-reversed so that F.conv1d (a correlation) applies them as a
        # convolution
        kernel_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        kernel_odd = get_impz(hq_2x_up_c_odd, 128)[::-1].copy()

        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(kernel_even).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(kernel_odd).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_padding = [127, 0]

        # fractional-delay interpolation filters for the 3/2 stage
        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_fir[0]).view(1, 1, -1), requires_grad=False)
        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_fir[8]).view(1, 1, -1), requires_grad=False)
        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_fir[4]).view(1, 1, -1), requires_grad=False)

        # 48 kHz keeps every 3x sample, 24 kHz takes every second one
        self.stride = 1 if fs_out == 48000 else 2

    def hq_2x_up(self, x):
        """2x upsampling: even and odd output phases are computed by two
        channel-wise FIR filters and interleaved."""

        channels = x.size(1)

        w_even = torch.repeat_interleave(self.hq_2x_up_even, channels, 0)
        w_odd = torch.repeat_interleave(self.hq_2x_up_odd, channels, 0)

        x_pad = F.pad(x, self.hq_2x_up_padding)
        phase_even = F.conv1d(x_pad, w_even, groups=channels)
        phase_odd = F.conv1d(x_pad, w_odd, groups=channels)

        # interleave phases: (..., n, 2) -> (..., 2n)
        return torch.stack((phase_even, phase_odd), dim=-1).flatten(2)

    def interpolate_3_2(self, x):
        """3/2 resampling: three fractional phases at stride 2, interleaved."""

        channels = x.size(1)

        w_a = torch.repeat_interleave(self.frac_01_24, channels, 0)
        w_b = torch.repeat_interleave(self.frac_17_24, channels, 0)
        w_c = torch.repeat_interleave(self.frac_09_24, channels, 0)

        x_pad = F.pad(x, [8, 0])
        phase_a = F.conv1d(x_pad, w_a, stride=2, groups=channels)
        phase_b = F.conv1d(x_pad, w_b, stride=2, groups=channels)
        # third phase reads one sample ahead; roll shifts the input left by one
        phase_c = F.conv1d(torch.roll(x_pad, -1, -1), w_c, stride=2, groups=channels)

        interleaved = torch.stack((phase_a, phase_b, phase_c), dim=-1).flatten(2)

        # drop the last 3 samples, which involve the wrapped-around sample
        return interleaved[..., :-3]

    def forward(self, x):
        """Upsample x of shape (batch, channels, num_samples) at fs_in to fs_out."""

        doubled = self.hq_2x_up(x)
        tripled = self.interpolate_3_2(doubled)

        return tripled[:, :, ::self.stride]
|
||||
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.complexity import _conv1d_flop_count
|
||||
from utils.softquant import soft_quant
|
||||
|
||||
class TDShaper(nn.Module):
    # class-level counter; not used by the logic in this class
    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
                 ):
        """
        Temporal envelope shaper.

        Learns a per-sample gain curve from frame-wise features combined with
        the signal's own log-domain temporal envelope; optionally adds a
        learned innovation (noise-like) component.

        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size in samples

        avg_pool_k : int, optional
            kernel size and stride for avg pooling of the envelope

        innovate : bool, optional
            if True, adds a second branch that generates and scales an
            innovation signal

        pool_after : bool, optional
            if True, takes the log before average pooling; otherwise pools
            first and then takes the log

        softquant : bool, optional
            if True, attaches soft-quantization noise to feature_alpha1_f

        apply_weight_norm : bool, optional
            if True, wraps all Conv1d layers in weight normalization
        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after

        # envelope has one pooled value per avg_pool_k samples plus the frame mean
        assert frame_size % avg_pool_k == 0
        self.env_dim = frame_size // avg_pool_k + 1

        # identity wrapper when weight norm is disabled (accepts the same signature)
        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform: separate 1x2 convs for features and temporal envelope
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            # NOTE(review): only feature_alpha1_f is soft-quantized, not
            # feature_alpha1_t / feature_alpha2 — confirm this is intentional
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            # NOTE(review): these convs expect feature_dim + env_dim input
            # channels but forward() feeds them `f` (feature_dim channels
            # only) — innovate=True looks inconsistent; confirm upstream
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))


    def flop_count(self, rate):
        """Estimate FLOPs per second at sample rate `rate`.

        Sums conv FLOPs at the frame rate; the additive terms (11x / 22x
        frame_rate * frame_size) presumably account for the element-wise
        envelope and scaling operations — confirm against the project's
        complexity accounting.
        """

        frame_rate = rate / self.frame_size

        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """Compute a mean-removed log-envelope of x plus the mean as an
        extra channel; output has env_dim entries per frame.
        """

        x = torch.abs(x)
        if self.pool_after:
            # log first, then pool (geometric-mean-like envelope)
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            # pool first, then log (arithmetic-mean envelope); .5**16 avoids log(0)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # regroup pooled values per frame: (batch, num_frames, env_dim - 1)
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        # mean-removed envelope with the per-frame mean appended as last entry
        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        Returns:
        --------
        torch.tensor
            shaped signal of shape (batch_size, 1, num_samples)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        frame_size = self.frame_size

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path: causal left-padding so the 1x2 convs see the previous frame
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        # exp keeps the per-sample gains strictly positive
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            # innovation gain (positive via exp)
            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            # innovation signal, bounded to [-1, 1] via tanh
            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path: apply per-sample gains frame by frame
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)
|
||||
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def load_lpcnet_features(feature_file, version=2):
    """Load a raw float32 LPCNet feature dump and split it into named tensors.

    Returns a dict with 'features' (cepstrum + pitch correlation),
    'periods' (integer pitch lags) and 'lpcs' (LPC coefficients).
    Raises ValueError for an unknown layout version.
    """
    if version == 2:
        frame_length = 36
        layout = {'cepstrum': (0, 18), 'periods': (18, 19),
                  'pitch_corr': (19, 20), 'lpc': (20, 36)}
    elif version == 1:
        frame_length = 55
        layout = {'cepstrum': (0, 18), 'periods': (36, 37),
                  'pitch_corr': (37, 38), 'lpc': (39, 55)}
    else:
        raise ValueError(f'unknown feature version: {version}')

    raw = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw = raw.reshape((-1, frame_length))

    def sel(name):
        # slice one named field out of the raw frame
        start, stop = layout[name]
        return raw[:, start:stop]

    features = torch.cat((sel('cepstrum'), sel('pitch_corr')), dim=1)
    lpcs = sel('lpc')
    # decode the normalized period track back to integer lags
    periods = (0.1 + 50 * sel('periods') + 100).long()

    return {'features': features, 'periods': periods, 'lpcs': lpcs}
|
||||
|
||||
|
||||
|
||||
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize the int16 signal at `signal_path` and write it into a new
    file shaped like `reference_data_path`.

    Side effects: creates `<signal_path stem>_preemph.raw` (the pre-emphasized
    signal) and `new_data_path` (same length as the reference data).
    The shifted signal is written to both odd and even sample positions —
    presumably matching a 2-channel interleaved layout of the reference
    dump; confirm against the consumer of `new_data_path`.
    """
    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)


    # process in 160-sample frames, carrying one sample of memory for the
    # pre-emphasis filter y[n] = x[n] - preemph_factor * x[n-1]
    assert len(signal) % 160 == 0
    num_frames = len(signal) // 160  # NOTE: currently unused
    mem = np.zeros(1)
    for fr in range(len(signal)//160):
        signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
        # keep last input sample as filter memory for the next frame
        mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)

    # zero-fill, then interleave the (delay-compensated) signal into both
    # sample streams starting at index 1 and 2 respectively
    new_data[:] = 0
    N = len(signal) - offset
    new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
|
||||
|
||||
|
||||
def parse_warpq_scores(output_file):
    """Extract all WARP-Q scores from a tool output file.

    Only lines that begin with the literal prefix "WARP-Q score:" are
    parsed; the trailing number is returned as a float.
    """
    scores = []
    with open(output_file, "r") as f:
        for line in f.readlines():
            if line.startswith("WARP-Q score:"):
                scores.append(float(line.split("WARP-Q score:")[-1]))
    return scores
|
||||
|
||||
|
||||
def parse_stats_file(file):
    """Read (mean, bt_mean, top_mean) from the first three lines of a stats
    file; each line has the value after the last ':'.
    """
    with open(file, "r") as f:
        lines = f.readlines()

    mean, bt_mean, top_mean = (float(lines[i].split(":")[-1]) for i in range(3))

    return mean, bt_mean, top_mean
|
||||
|
||||
def collect_test_stats(test_folder):
    """Collect statistics for all discovered metrics from a test folder.

    Scans `test_folder` for files named stats_<metric>.txt and returns
    {metric: [mean, bt_mean, top_mean]}. Unknown metric names are still
    collected, with a warning printed.
    """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}

    results = dict()

    for fname in os.listdir(test_folder):
        if not fname.startswith('stats_'):
            continue

        metric = fname[len("stats_") : -len(".txt")]
        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, fname))
        results[metric] = [mean, bt_mean, top_mean]

    return results
|
||||
95
managed_components/78__esp-opus/dnn/torch/osce/utils/misc.py
Normal file
95
managed_components/78__esp-opus/dnn/torch/osce/utils/misc.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def count_parameters(model, verbose=False):
    """Return the total number of parameters in `model`.

    With verbose=True, prints a per-parameter breakdown.
    """
    total = 0
    for name, param in model.named_parameters():
        n = torch.ones_like(param).sum().item()
        if verbose:
            print(f"{name}: {n} parameters")
        total = total + n

    return total
|
||||
|
||||
def count_nonzero_parameters(model, verbose=False):
    """Return the number of non-zero parameter entries in `model`.

    Useful for measuring effective sparsity; verbose=True prints a
    per-parameter breakdown.
    """
    total = 0
    for name, param in model.named_parameters():
        nonzero = torch.count_nonzero(param).item()
        if verbose:
            print(f"{name}: {nonzero} non-zero parameters")
        total = total + nonzero

    return total
|
||||
def retain_grads(module):
    """Ask every trainable parameter of `module` to keep its .grad after
    backward (a no-op for leaf parameters, which already do)."""
    for param in module.parameters():
        if param.requires_grad:
            param.retain_grad()
|
||||
|
||||
def get_grad_norm(module, p=2):
    """Return the p-norm over all gradients of trainable parameters of
    `module` (gradients must already be populated)."""
    accum = 0
    for param in module.parameters():
        if param.requires_grad:
            accum = accum + (torch.abs(param.grad) ** p).sum()

    return accum ** (1 / p)
|
||||
|
||||
def create_weights(s_real, s_gen, alpha):
    """Per-discriminator weights exp(alpha * (real_score - gen_score)).

    `s_real` / `s_gen` are lists of per-discriminator output lists; only the
    last entry of each is used. Computed without gradient tracking.
    """
    with torch.no_grad():
        return [torch.exp(alpha * (real[-1] - gen[-1]))
                for real, gen in zip(s_real, s_gen)]
|
||||
|
||||
|
||||
def _get_candidates(module: torch.nn.Module):
|
||||
candidates = []
|
||||
for key in module.__dict__.keys():
|
||||
if hasattr(module, key + '_v'):
|
||||
candidates.append(key)
|
||||
return candidates
|
||||
|
||||
def remove_all_weight_norm(model: torch.nn.Module, verbose=False):
    """Strip weight normalization from every submodule of `model`.

    Candidates are discovered via their `<name>_v` shadow attributes and the
    reparametrization is removed in place.
    """
    for name, m in model.named_modules():
        candidates = _get_candidates(m)

        for candidate in candidates:
            # remove_weight_norm raises ValueError when the attribute is not
            # actually weight-normed; that case is expected and skipped.
            # (Previously a bare `except:` here also swallowed
            # KeyboardInterrupt/SystemExit and real bugs.)
            try:
                remove_weight_norm(m, name=candidate)
                if verbose: print(f'removed weight norm on weight {name}.{candidate}')
            except ValueError:
                pass
|
||||
153
managed_components/78__esp-opus/dnn/torch/osce/utils/moc.py
Normal file
153
managed_components/78__esp-opus/dnn/torch/osce/utils/moc.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
|
||||
|
||||
frame_length = (fs + 49) // 50
|
||||
x = x[: frame_length * (len(x) // frame_length)]
|
||||
|
||||
frames = x.reshape(-1, frame_length)
|
||||
frame_energy = np.sum(frames ** 2, axis=1)
|
||||
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
|
||||
|
||||
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
|
||||
vactive = np.ones_like(frames)
|
||||
vactive[frame_energy_smooth < max_threshold, :] = 0
|
||||
vactive = vactive.reshape(-1)
|
||||
|
||||
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
|
||||
filter = filter / filter.sum()
|
||||
|
||||
mask = np.convolve(vactive, filter, mode='same')
|
||||
|
||||
return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Resample a per-sample mask to one averaged value per analysis frame
    (frames of `frame_size` samples, advancing by `hop_size`)."""
    num_samples = frame_size + (num_frames - 1) * hop_size
    # zero-pad or trim so every frame is fully covered
    if len(mask) < num_samples:
        mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]

    starts = (i * hop_size for i in range(num_frames))
    return np.array([np.mean(mask[s : s + frame_size]) for s in starts])
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Windowed FFT power spectrum of `x`: one row per hop, bins DC..Nyquist."""
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    win = scipy.signal.get_window(window, window_size)
    half = window_size // 2

    frames = np.stack([x[i * hop_size : i * hop_size + window_size]
                       for i in range(num_spectra)]) * win
    return np.abs(np.fft.fft(frames, axis=1)[:, : half + 1]) ** 2
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
    """Spectral spreading matrix over bands.

    Energy leaks geometrically upward in frequency with ratio `up_factor`
    and downward with ratio `down_factor`; the two spreads are composed as
    down @ up.
    """
    spread_up = np.zeros((num_bands, num_bands))
    spread_down = np.zeros((num_bands, num_bands))

    for band in range(num_bands):
        spread_up[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        spread_down[band, band:] = down_factor ** np.arange(num_bands - band)

    return spread_down @ spread_up
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank; band i covers bins
    [band_limits[i], band_limits[i+1])."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        num_bins = band_limits[-1]

    fb = np.zeros((num_bands, num_bins))
    for i, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[i, lo:hi] = 1

    return fb
|
||||
|
||||
|
||||
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, score only voice-active frames of x

    Returns:
        float: perceptually weighted error (lower is better)
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim samples to same size, rescale to int16 range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    # power spectra with an additive floor (avoids log-of-zero below)
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000

    num_frames = psd_x.shape[0]

    # average band energies
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking (spread reference band energies across bands)
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: leaky integration over frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]

    # apply mask — the same reference-derived mask is added to both signals
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band then per frame
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)

    if apply_vad:
        # restrict scoring to voice-active frames of the reference
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic mean over active frames, compressed with the 1/6 exponent
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: compute the MOC score between a reference and a
    # degraded 16 kHz wav file.
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()


    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # Both inputs must be 16 kHz. The previous check `max(fs1, fs2) != 16000`
    # silently accepted a mismatched (e.g. 8 kHz) file as long as the other
    # one was 16 kHz.
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency different from 16kHz')

    # scale int16 PCM to [-1, 1] as expected by compare()
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)

    print(f"MOC: {err}")
|
||||
122
managed_components/78__esp-opus/dnn/torch/osce/utils/pitch.py
Normal file
122
managed_components/78__esp-opus/dnn/torch/osce/utils/pitch.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
def hangover(lags, num_frames=10):
    """Bridge short unvoiced gaps (zeros) in a pitch-lag track.

    Up to `num_frames` consecutive zero entries are replaced with the last
    non-zero lag; longer gaps are left unvoiced. Returns a new array.
    """
    filled = lags.copy()
    run = 0
    previous = 0

    for i, lag in enumerate(lags):
        if lag == 0:
            # inside a gap: extend the previous lag for a limited number of frames
            if run < num_frames:
                filled[i] = previous
                run += 1
        else:
            run = 0
            previous = lag

    return filled
|
||||
|
||||
|
||||
def smooth_pitch_lags(lags, d=2):
    """Smooth a pitch-lag track with a triangular kernel of half-width `d`.

    Smoothing runs per SILK frame (4 sub-frame lags at a time). All-zero
    (unvoiced) frames are skipped; unvoiced/missing neighbours are replaced
    by mirrored edge values of the current frame so the kernel never mixes
    zeros into a voiced frame. Returns a new array of rounded lags.
    """

    # kernel half-width must fit within one 4-sub-frame SILK frame
    assert d < 4

    num_silk_frames = len(lags) // 4

    smoothed_lags = lags.copy()

    # triangular kernel [1 .. d, d+1, d .. 1], normalized to unit sum
    tmp = np.arange(1, d+1)
    kernel = np.concatenate((tmp, [d+1], tmp[::-1]), dtype=np.float32)
    kernel = kernel / np.sum(kernel)

    # left context: mirror of the first d lags
    last = lags[0:d][::-1]
    for i in range(num_silk_frames):
        frame = lags[i * 4: (i+1) * 4]

        # unvoiced frame: leave untouched, update left context and move on
        if np.max(np.abs(frame)) == 0:
            last = frame[4-d:]
            continue

        # right context: next frame's first d lags, or a mirror at the track end
        if i == num_silk_frames - 1:
            next = frame[4-d:][::-1]
        else:
            next = lags[(i+1) * 4 : (i+1) * 4 + d]

        # replace unvoiced neighbours with the frame's own mirrored edges
        if np.max(np.abs(next)) == 0:
            next = frame[4-d:][::-1]

        if np.max(np.abs(last)) == 0:
            last = frame[0:d][::-1]

        smoothed_frame = np.convolve(np.concatenate((last, frame, next), dtype=np.float32), kernel, mode='valid')

        smoothed_lags[i * 4: (i+1) * 4] = np.round(smoothed_frame)

        last = frame[4-d:]

    return smoothed_lags
|
||||
|
||||
def calculate_acorr_window(x, frame_size, lags, history=None, max_lag=300, radius=2, add_double_lag_acorr=False, no_pitch_threshold=32):
    """Normalized autocorrelation of each frame of `x` around its pitch lag.

    For every frame, correlates the frame against past samples at offsets
    lag-radius .. lag+radius (and, if add_double_lag_acorr, also around
    2*lag). Lags below `no_pitch_threshold` are treated as unvoiced and are
    not doubled. Returns (acorrs, lags) where acorrs has shape
    (num_frames, (2*radius+1) * lag_multiplier).
    """
    # guards the normalization against all-zero frames
    eps = 1e-9

    lag_multiplier = 2 if add_double_lag_acorr else 1

    if history is None:
        history = np.zeros(lag_multiplier * max_lag + radius, dtype=x.dtype)

    offset = len(history)

    # history must cover the largest possible lookback
    assert offset >= max_lag + radius
    assert len(x) % frame_size == 0

    num_frames = len(x) // frame_size
    lags = lags.copy()

    x_ext = np.concatenate((history, x), dtype=x.dtype)

    d = radius
    num_acorrs = 2 * d + 1
    acorrs = np.zeros((num_frames, lag_multiplier * num_acorrs), dtype=x.dtype)

    for idx in range(num_frames):
        lag = lags[idx].item()
        frame = x_ext[offset + idx * frame_size : offset + (idx + 1) * frame_size]

        for k in range(lag_multiplier):
            # double the lag only for plausible (voiced) pitch values
            lag1 = (k + 1) * lag if lag >= no_pitch_threshold else lag
            for j in range(num_acorrs):
                # past frame shifted by lag1 - d + j samples
                past = x_ext[offset + idx * frame_size - lag1 + j - d : offset + (idx + 1) * frame_size - lag1 + j - d]
                # normalized cross-correlation in [-1, 1] (up to eps)
                acorrs[idx, j + k * num_acorrs] = np.dot(frame, past) / np.sqrt(np.dot(frame, frame) * np.dot(past, past) + eps)

    return acorrs, lags
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import scipy
|
||||
import scipy.signal
|
||||
|
||||
from utils.pitch import hangover, calculate_acorr_window
|
||||
from utils.spec import create_filter_bank, cepstrum, log_spectrum, log_spectrum_from_lpc
|
||||
|
||||
def spec_from_lpc(a, n_fft=128, eps=1e-9):
    """Power spectrum of the all-pole filter 1/A(z) given LPC coefficients.

    `a` holds the prediction coefficients along the last axis; the result
    has n_fft//2 + 1 frequency bins per coefficient set.
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # impulse response of A(z) = 1 - sum_k a_k z^{-k}
    poly = np.zeros((*a.shape[:-1], n_fft))
    poly[..., 0] = 1
    poly[..., 1 : 1 + order] = -a

    # |A(e^jw)|^2 on the non-negative half of the spectrum
    mag2 = np.abs(np.fft.fft(poly, axis=-1)[..., : n_fft // 2 + 1]) ** 2

    # all-pole spectrum, eps-regularized against division by zero
    return 1 / (mag2 + eps)
|
||||
|
||||
def silk_feature_factory(no_pitch_value=256,
                         acorr_radius=2,
                         pitch_hangover=8,
                         num_bands_clean_spec=64,
                         num_bands_noisy_spec=18,
                         noisy_spec_scale='opus',
                         noisy_apply_dct=True,
                         add_double_lag_acorr=False
                         ):
    """Build a feature-extraction closure for SILK enhancement models.

    The returned function maps raw SILK parameters of one signal chunk to a
    concatenated per-sub-frame feature matrix plus integer pitch periods.
    """

    # window and filter banks are built once and captured by the closure
    w = scipy.signal.windows.cosine(320)
    fb_clean_spec = create_filter_bank(num_bands_clean_spec, 320, scale='erb', round_center_bins=True, normalize=True)
    fb_noisy_spec = create_filter_bank(num_bands_noisy_spec, 320, scale=noisy_spec_scale, round_center_bins=True, normalize=True)

    def create_features(noisy, noisy_history, lpcs, gains, ltps, periods):
        # work on a copy; the pitch track is rewritten below
        periods = periods.copy()

        # bridge short unvoiced gaps in the pitch track
        if pitch_hangover > 0:
            periods = hangover(periods, num_frames=pitch_hangover)

        # remaining zeros mean "no pitch": replace with the sentinel lag
        periods[periods == 0] = no_pitch_value

        # clean spectral envelope derived from the quantized LPCs
        clean_spectrum = 0.3 * log_spectrum_from_lpc(lpcs, fb=fb_clean_spec, n_fft=320)

        # noisy-signal spectrum (cepstrum or log spectrum over the last
        # 160 history samples + current chunk), repeated to sub-frame rate
        if noisy_apply_dct:
            noisy_cepstrum = np.repeat(
                cepstrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)
        else:
            noisy_cepstrum = np.repeat(
                log_spectrum(np.concatenate((noisy_history[-160:], noisy), dtype=np.float32), 320, fb_noisy_spec, w), 2, 0)

        # 1e-9 floor avoids log(0) for silent gains
        log_gains = np.log(gains + 1e-9).reshape(-1, 1)

        # normalized autocorrelation around the (possibly doubled) pitch lag
        acorr, _ = calculate_acorr_window(noisy, 80, periods, noisy_history, radius=acorr_radius, add_double_lag_acorr=add_double_lag_acorr)

        features = np.concatenate((clean_spectrum, noisy_cepstrum, acorr, ltps, log_gains), axis=-1, dtype=np.float32)

        return features, periods.astype(np.int64)

    return create_features
|
||||
|
||||
|
||||
|
||||
def load_inference_data(path,
                        no_pitch_value=256,
                        skip=92,
                        preemph=0.85,
                        acorr_radius=2,
                        pitch_hangover=8,
                        num_bands_clean_spec=64,
                        num_bands_noisy_spec=18,
                        noisy_spec_scale='opus',
                        noisy_apply_dct=True,
                        add_double_lag_acorr=False,
                        **kwargs):
    """Load a coded-signal dump folder and assemble model inputs.

    Reads the raw SILK feature files from `path`, trims all streams to a
    common number of sub-frames, builds the feature matrix via
    silk_feature_factory and returns (signal, features, periods, numbits)
    as torch tensors. Unrecognized keyword arguments are ignored.
    """

    print(f"[load_inference_data]: ignoring keyword arguments {kwargs.keys()}...")

    lpcs = np.fromfile(os.path.join(path, 'features_lpc.f32'), dtype=np.float32).reshape(-1, 16)
    ltps = np.fromfile(os.path.join(path, 'features_ltp.f32'), dtype=np.float32).reshape(-1, 5)
    gains = np.fromfile(os.path.join(path, 'features_gain.f32'), dtype=np.float32)
    periods = np.fromfile(os.path.join(path, 'features_period.s16'), dtype=np.int16)
    num_bits = np.fromfile(os.path.join(path, 'features_num_bits.s32'), dtype=np.int32).astype(np.float32).reshape(-1, 1)
    num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32).reshape(-1, 1)

    # load signal, add back delay and pre-emphasize
    signal = np.fromfile(os.path.join(path, 'noisy.s16'), dtype=np.int16).astype(np.float32) / (2 ** 15)
    signal = np.concatenate((np.zeros(skip, dtype=np.float32), signal), dtype=np.float32)

    create_features = silk_feature_factory(no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec, num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_double_lag_acorr)

    # trim all streams to a whole number of 20 ms frames (4 sub-frames each)
    num_frames = min((len(signal) // 320) * 4, len(lpcs))
    signal = signal[: num_frames * 80]
    lpcs = lpcs[: num_frames]
    ltps = ltps[: num_frames]
    gains = gains[: num_frames]
    periods = periods[: num_frames]
    num_bits = num_bits[: num_frames // 4]
    # bug fix: this previously sliced `num_bits`, silently discarding the
    # smoothed bitrate stream loaded above
    num_bits_smooth = num_bits_smooth[: num_frames // 4]

    numbits = np.repeat(np.concatenate((num_bits, num_bits_smooth), axis=-1, dtype=np.float32), 4, axis=0)

    features, periods = create_features(signal, np.zeros(350, dtype=signal.dtype), lpcs, gains, ltps, periods)

    if preemph > 0:
        signal[1:] -= preemph * signal[:-1]

    return torch.from_numpy(signal), torch.from_numpy(features), torch.from_numpy(periods), torch.from_numpy(numbits)
|
||||
@@ -0,0 +1,110 @@
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Per-output-row quantization scale for a 2-D weight matrix.

    The scale must accommodate both the largest single weight (mapped to
    127) and the largest magnitude of adjacent-column pair sums (mapped to
    129). Returns a 1-D tensor with one scale per original output row.
    """
    # (the outer decorator already disables grad; the previous redundant
    # inner `with torch.no_grad():` block has been dropped)
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # pad the output dimension up to the next multiple of 8
        # (bug fix: the pad amount was `n_out - n_out % 8`, i.e. nearly
        # doubling the matrix instead of adding the missing remainder)
        pad = -n_out % 8
        weight = torch.cat((weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)), dim=0)

    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    scale = torch.maximum(scale_max, scale_sum)

    # padded rows (all zeros) are discarded again
    return scale[:n_out]
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Uniform noise in [-scale/2, scale/2) per output row, where the scale
    matches the row-wise quantization scale of `weight` for the given
    module type. Raises ValueError for unsupported weight layouts."""
    if isinstance(module, torch.nn.Conv1d):
        # fold the kernel dim into rows: (out, in, k) -> (out, k*in)
        flat = weight.permute(0, 2, 1).flatten(1)
        row_noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = row_noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        # transpose layout is (in, out, k); fold to (k*out, in)
        i, o, k = weight.shape
        flat = weight.permute(2, 1, 0).reshape(k * o, i)
        row_noise = (torch.rand_like(flat) - 0.5) * compute_optimal_scale(flat).unsqueeze(-1)
        noise = row_noise.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        noise = (torch.rand_like(weight) - 0.5) * compute_optimal_scale(weight).unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')

    return noise
|
||||
|
||||
class SoftQuant:
    """Forward-hook pair that injects quantization noise into selected
    weights during training and restores them right after the forward pass.

    Attach via SoftQuant.apply(module, names, scale); with scale=None the
    noise is scaled per output row by the optimal int8 quantization step,
    otherwise by `scale * weight.abs().max()`.
    """
    name: str

    def __init__(self, names: str, scale: float) -> None:
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        # only perturb weights while training
        if not module.training: return

        if before:
            # pre-hook: record and add noise to each selected weight
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            # post-hook: undo exactly the perturbation recorded above
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=['weight'], scale=None):
        """Register the pre/post hooks on `module` and return the SoftQuant
        instance. Raises ValueError if a name is not an attribute of the
        module. (`names` is read-only; the mutable default is never mutated.)
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                # previously raised ValueError("") with no explanation
                raise ValueError(f'module has no attribute named {name}')

        fn_before = lambda *x : fn(*x, before=True)
        fn_after = lambda *x : fn(*x, before=False)
        # tag both hooks so remove_soft_quant can find them again
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)


        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        # (a stray no-op `module` expression statement was removed here)

        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=['weight'], scale=None):
    """Attach SoftQuant hooks to `module` and return the module (fluent
    style, so calls can be chained at construction time)."""
    # previously the returned SoftQuant instance was bound to an unused local
    SoftQuant.apply(module, names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=['weight']):
    """Detach any SoftQuant hooks registered for exactly `names` from
    `module` and return the module.

    Iterates over snapshots of the hook dicts: the previous version deleted
    entries from `_forward_pre_hooks` / `_forward_hooks` while iterating
    them, which raises RuntimeError as soon as a hook actually matches.
    """
    for k, hook in list(module._forward_pre_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_hooks[k]

    return module
|
||||
210
managed_components/78__esp-opus/dnn/torch/osce/utils/spec.py
Normal file
210
managed_components/78__esp-opus/dnn/torch/osce/utils/spec.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.fftpack
|
||||
import torch
|
||||
|
||||
def erb(f):
    """Equivalent rectangular bandwidth at frequency f (Glasberg-Moore fit)."""
    scaled = 4.37 * f + 1
    return 24.7 * scaled
|
||||
|
||||
def inv_erb(e):
    """Inverse of erb(): frequency whose equivalent rectangular bandwidth is e."""
    scaled = e / 24.7
    return (scaled - 1) / 4.37
|
||||
|
||||
def bark(f):
    """Bark-scale value of frequency f."""
    ratio = f / 600
    return 6 * m.asinh(ratio)
|
||||
|
||||
def inv_bark(b):
    """Inverse of bark(): frequency whose Bark-scale value is b."""
    ratio = b / 6
    return 600 * m.sinh(ratio)
|
||||
|
||||
|
||||
# Frequency-scale registry: maps a scale name to a [to_scale, from_scale]
# pair of conversion functions (Hz <-> scale units).
scale_dict = {
    'bark': [bark, inv_bark],
    'erb': [erb, inv_erb]
}
|
||||
|
||||
def gen_filterbank(N, Fs=16000, keep_size=False):
    """Build an ERB-spaced smoothing filter bank as a torch tensor.

    Args:
        N: number of input bins minus one (input axis has N+1 bins).
        Fs: sampling rate in Hz used to place the bin frequencies.
        keep_size: if True, produce N+1 output rows instead of N.

    Returns:
        torch.Tensor of shape (N or N+1, N+1); each row is normalized to
        sum to one.
    """
    freqs_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    n_out = N + 1 if keep_size else N
    freqs_out = (np.arange(n_out, dtype='float32') / N * Fs / 2)[:, None]
    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb_width = 24.7 + .108 * freqs_in
    delta = np.abs(freqs_in - freqs_out) / erb_width
    near = (delta < .5).astype('float32')
    # piecewise response in dB: quadratic near the center, linear tail further out
    response_db = -12 * near * delta ** 2 + (1 - near) * (3 - 12 * delta)
    weights = 10. ** (response_db / 10.)
    # normalize each output row to unit sum
    weights = weights / np.sum(weights, axis=1)[:, np.newaxis]
    return torch.from_numpy(weights)
|
||||
|
||||
def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
    """Create a triangular filter bank over FFT bins.

    Args:
        num_bands: number of bands (forced to 18 for scale='opus').
        n_fft: FFT size; the bank covers n_fft // 2 + 1 bins.
        fs: sampling rate in Hz.
        scale: 'opus' for the fixed Opus band layout, otherwise a key of
            scale_dict ('bark' or 'erb') for perceptually spaced centers.
        round_center_bins: round band centers to integer bins.
        return_upper: mirror the bank onto the upper (negative-frequency)
            half of the spectrum as well.
        normalize: normalize each band to unit sum.

    Returns:
        ndarray of shape (num_bands, n_fft // 2 + 1), or with the mirrored
        columns appended when return_upper is True.
    """

    f0 = 0
    num_bins = n_fft // 2 + 1
    f1 = fs / n_fft * (num_bins - 1)
    fstep = fs / n_fft

    if scale == 'opus':
        # Opus band edges expressed in 5 ms frame units; fac converts to FFT bins.
        bins_5ms = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40]
        fac = 1000 * n_fft / fs / 5
        if num_bands != 18:
            print("warning: requested Opus filter bank with num_bands != 18. Adjusting num_bands.")
            num_bands = 18
        center_bins = np.array([fac * edge for edge in bins_5ms])
    else:
        to_scale, from_scale = scale_dict[scale]

        s0 = to_scale(f0)
        s1 = to_scale(f1)

        # equally spaced on the perceptual scale, endpoints pinned to f0 and f1
        inner = [from_scale(s0 + i * (s1 - s0) / (num_bands)) for i in range(1, num_bands - 1)]
        center_freqs = np.array([f0] + inner + [f1])
        center_bins = (center_freqs - f0) / fstep

    if round_center_bins:
        center_bins = np.round(center_bins)

    filter_bank = np.zeros((num_bands, num_bins))

    b = 0
    for k in range(num_bins):
        # advance to the band pair whose centers bracket bin k
        if k > center_bins[b + 1]:
            b += 1

        # linear interpolation between the two neighbouring band centers
        frac = (center_bins[b + 1] - k) / (center_bins[b + 1] - center_bins[b])
        filter_bank[b, k] = frac
        filter_bank[b + 1, k] = 1 - frac

    if return_upper:
        extend = n_fft - num_bins
        filter_bank = np.concatenate((filter_bank, np.fliplr(filter_bank[:, 1:extend + 1])), axis=1)

    if normalize:
        filter_bank = filter_bank / np.sum(filter_bank, axis=1).reshape(-1, 1)

    return filter_bank
|
||||
|
||||
|
||||
def compressed_log_spec(pspec):
    """Log10 of a power spectrum with a rising floor.

    Each band's log value is clamped from below by the running maximum and
    by a follower that decays 2.5 per band.

    NOTE(review): log_max never decays (it is max'd with every new value),
    so the floor is monotonically non-decreasing across bands — confirm
    this is intended rather than e.g. log_max - <margin>.
    """

    lpspec = np.zeros_like(pspec)
    num_bands = pspec.shape[-1]

    log_max = -2
    follow = -2

    for band in range(num_bands):
        level = np.log10(pspec[band] + 1e-9)
        # floor by both the running max and the decaying follower
        level = max(log_max, follow - 2.5, level)
        lpspec[band] = level
        log_max = max(log_max, level)
        follow = max(follow - 2.5, level)

    return lpspec
|
||||
|
||||
def log_spectrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False, power=1):
    """Compute a (optionally band-aggregated) log spectrum from SILK LPCs.

    Args:
        a: LPC coefficients, order along the last axis.
        fb: optional filter bank matrix applied to the magnitude spectrum.
        n_fft: FFT size; must satisfy order + 1 < n_fft.
        eps: numerical floor for inversion and log.
        gamma: bandwidth-expansion factor applied to the coefficients.
        compress: use compressed_log_spec instead of a plain log.
        power: exponent applied to the magnitude spectrum.

    Returns:
        Log spectrum with shape a.shape[:-1] + (num_bins or num_bands,).
    """
    order = a.shape[-1]
    assert order + 1 < n_fft

    # bandwidth expansion: scale coefficient k by gamma**k
    a = a * (gamma ** (1 + np.arange(order))).astype(np.float32)

    # prediction-error filter A(z): leading 1, then -a, zero padded to n_fft
    coeffs = np.zeros((*a.shape[:-1], n_fft))
    coeffs[..., 0] = 1
    coeffs[..., 1:1 + order] = -a

    spec = np.fft.fft(coeffs, axis=-1)
    spec = np.abs(spec[..., :n_fft // 2 + 1]) ** power

    # synthesis-filter magnitude response is the inverse of |A|
    synth = 1 / (spec + eps)

    if fb is None:
        banded = synth
    else:
        banded = np.matmul(synth, fb.T)

    if compress:
        banded = np.apply_along_axis(compressed_log_spec, -1, banded)
    else:
        banded = np.log(banded + eps)

    return banded
|
||||
|
||||
def cepstrum_from_lpc(a, fb=None, n_fft=320, eps=1e-9, gamma=1, compress=False):
    """Compute cepstral coefficients from SILK LPCs.

    Takes the log spectrum derived from the LPCs and applies an
    orthonormal DCT-II along the last axis.
    """

    log_spec = log_spectrum_from_lpc(a, fb, n_fft, eps, gamma, compress)

    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
|
||||
|
||||
|
||||
|
||||
def log_spectrum(x, frame_size, fb=None, window=None, power=1):
    """Log magnitude spectrum on 50% overlapping frames.

    Args:
        x: 1-D signal; 2 * len(x) must be a multiple of frame_size.
        frame_size: even analysis frame length.
        fb: optional filter bank applied to the magnitude spectrum.
        window: optional analysis window of length frame_size.
        power: exponent applied to the magnitudes before banding.

    Returns:
        ndarray of shape (num_frames, frame_size // 2 + 1) — or
        (num_frames, num_bands) when fb is given.
    """

    assert (2 * len(x)) % frame_size == 0
    assert frame_size % 2 == 0

    hop = frame_size // 2
    n = len(x)
    num_even = n // frame_size
    num_odd = (n - hop) // frame_size
    num_bins = frame_size // 2 + 1

    # frames at even offsets (0, frame_size, ...) and odd offsets (hop, hop+frame_size, ...)
    even_frames = x[:num_even * frame_size].reshape(-1, frame_size)
    odd_frames = x[hop: hop + frame_size * num_odd].reshape(-1, frame_size)

    # interleave so rows are in temporal order: even, odd, even, odd, ...
    frames = np.empty((even_frames.size + odd_frames.size), dtype=x.dtype).reshape((-1, frame_size))
    frames[::2, :] = even_frames
    frames[1::2, :] = odd_frames

    if window is not None:
        frames *= window.reshape(1, -1)

    mag = np.abs(np.fft.fft(frames, n=frame_size, axis=-1))[:, :num_bins] ** power

    if fb is not None:
        mag = np.matmul(mag, fb.T)

    return np.log(mag + 1e-9)
|
||||
|
||||
|
||||
def cepstrum(x, frame_size, fb=None, window=None):
    """Cepstrum on 50% overlapping frames.

    Applies an orthonormal DCT-II to the framewise log spectrum.
    """

    log_spec = log_spectrum(x, frame_size, fb, window)

    return scipy.fftpack.dct(log_spec, 2, norm='ortho')
|
||||
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
# Placeholder; rebound to the populated name -> setup mapping at the bottom of this file.
setup_dict = dict()
|
||||
|
||||
# Default regression-training setup for the LACE model.
lace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'lace',
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 128,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fix: 'pitch_hangover' was listed twice (8, then 0); the later
        # entry silently won, so keep the effective value 0 exactly once.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        # loss weights; only the active terms are non-zero
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
|
||||
|
||||
|
||||
# Default regression-training setup for the NoLACE model.
nolace_setup = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [10000, 40000, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fix: 'pitch_hangover' was listed twice (8, then 0); the later
        # entry silently won, so keep the effective value 0 exactly once.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'training': {
        'batch_size': 256,
        'lr': 5.e-4,
        'lr_decay_factor': 2.5e-5,
        'epochs': 50,
        # loss weights; only the active terms are non-zero
        'loss': {
            'w_l1': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_wsc': 0,
            'w_xcorr': 0,
            'w_sxcorr': 1,
            'w_l2': 10,
            'w_slm': 2
        }
    }
}
|
||||
|
||||
# Adversarial (GAN) training setup for NoLACE: adds a discriminator and
# adversarial loss weighting on top of the regression configuration.
nolace_setup_adv = {
    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
    'model': {
        'name': 'nolace',
        'args': [],
        'kwargs': {
            'avg_pool_k': 4,
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'hidden_feature_dim': 96,
            'kernel_size': 15,
            'num_features': 93,
            'numbits_embedding_dim': 8,
            'numbits_range': [50, 650],
            'partial_lookahead': True,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'skip': 91,
            'softquant': True,
            'sparsify': False,
            'sparsification_density': 0.4,
            'sparsification_schedule': [0, 0, 200]
        }
    },
    'data': {
        'frames_per_sample': 100,
        'no_pitch_value': 7,
        'preemph': 0.85,
        'skip': 91,
        # fix: 'pitch_hangover' was listed twice (8, then 0); the later
        # entry silently won, so keep the effective value 0 exactly once.
        'pitch_hangover': 0,
        'acorr_radius': 2,
        'num_bands_clean_spec': 64,
        'num_bands_noisy_spec': 18,
        'noisy_spec_scale': 'opus',
    },
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'training': {
        'adv_target': 'target_orig',
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        'loss': {
            'w_l1': 0,
            'w_l2': 10,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 20,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
    }
}
|
||||
|
||||
|
||||
# Regression-training setup for the LaVoce vocoder model.
lavoce_setup = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 256,
        'epochs': 50,
        # loss weights; only the active terms are non-zero
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0005,
        'lr_decay_factor': 2.5e-05
    },
    'validation_dataset': '/local/datasets/lpcnet_large/validation'
}
|
||||
|
||||
# Adversarial (GAN) training setup for the LaVoce vocoder model: adds a
# discriminator and uses a lower learning rate than the regression setup.
lavoce_setup_adv = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal'
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [
                64,
                128,
                256,
                512,
                1024,
                2048,
            ],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True
        },
        'name': 'lavoce'
    },
    'training': {
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        # loss weights; only the active terms are non-zero
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09
    },
}
|
||||
|
||||
|
||||
# Name -> setup lookup; selects one of the configurations above by model
# name ('_adv' variants select adversarial training).
setup_dict = {
    'lace': lace_setup,
    'nolace': nolace_setup,
    'nolace_adv': nolace_setup_adv,
    'lavoce': lavoce_setup,
    'lavoce_adv': lavoce_setup_adv
}
|
||||
Reference in New Issue
Block a user