add some code
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from . import sparsification
|
||||
from . import data
|
||||
from . import pcm
|
||||
from . import sample
|
||||
141
managed_components/78__esp-opus/dnn/torch/lpcnet/utils/data.py
Normal file
141
managed_components/78__esp-opus/dnn/torch/lpcnet/utils/data.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def load_features(feature_file, version=2):
|
||||
if version == 2:
|
||||
layout = {
|
||||
'cepstrum': [0,18],
|
||||
'periods': [18, 19],
|
||||
'pitch_corr': [19, 20],
|
||||
'lpc': [20, 36]
|
||||
}
|
||||
frame_length = 36
|
||||
|
||||
elif version == 1:
|
||||
layout = {
|
||||
'cepstrum': [0,18],
|
||||
'periods': [36, 37],
|
||||
'pitch_corr': [37, 38],
|
||||
'lpc': [39, 55],
|
||||
}
|
||||
frame_length = 55
|
||||
else:
|
||||
raise ValueError(f'unknown feature version: {version}')
|
||||
|
||||
|
||||
raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
|
||||
raw_features = raw_features.reshape((-1, frame_length))
|
||||
|
||||
features = torch.cat(
|
||||
[
|
||||
raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
|
||||
raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
|
||||
],
|
||||
dim=1
|
||||
)
|
||||
|
||||
lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
|
||||
periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
|
||||
|
||||
return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
|
||||
|
||||
|
||||
|
||||
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
|
||||
ref_data = np.memmap(reference_data_path, dtype=np.int16)
|
||||
signal = np.memmap(signal_path, dtype=np.int16)
|
||||
|
||||
signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
|
||||
signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
|
||||
|
||||
|
||||
assert len(signal) % 160 == 0
|
||||
num_frames = len(signal) // 160
|
||||
mem = np.zeros(1)
|
||||
for fr in range(len(signal)//160):
|
||||
signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
|
||||
mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
|
||||
|
||||
new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
|
||||
|
||||
new_data[:] = 0
|
||||
N = len(signal) - offset
|
||||
new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
|
||||
new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
|
||||
|
||||
|
||||
def parse_warpq_scores(output_file):
|
||||
""" extracts warpq scores from output file """
|
||||
|
||||
with open(output_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def parse_stats_file(file):
|
||||
|
||||
with open(file, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
mean = float(lines[0].split(":")[-1])
|
||||
bt_mean = float(lines[1].split(":")[-1])
|
||||
top_mean = float(lines[2].split(":")[-1])
|
||||
|
||||
return mean, bt_mean, top_mean
|
||||
|
||||
def collect_test_stats(test_folder):
|
||||
""" collects statistics for all discovered metrics from test folder """
|
||||
|
||||
metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
|
||||
|
||||
results = dict()
|
||||
|
||||
content = os.listdir(test_folder)
|
||||
|
||||
stats_files = [file for file in content if file.startswith('stats_')]
|
||||
|
||||
for file in stats_files:
|
||||
metric = file[len("stats_") : -len(".txt")]
|
||||
|
||||
if metric not in metrics:
|
||||
print(f"warning: unknown metric {metric}")
|
||||
|
||||
mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))
|
||||
|
||||
results[metric] = [mean, bt_mean, top_mean]
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
|
||||
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
|
||||
|
||||
# reset gate
|
||||
start, stop = 0 * hidden_size, 1 * hidden_size
|
||||
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# update gate
|
||||
start, stop = 1 * hidden_size, 2 * hidden_size
|
||||
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# new gate
|
||||
start, stop = 2 * hidden_size, 3 * hidden_size
|
||||
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
|
||||
|
||||
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
|
||||
""" sets up output folder for endoscopy data """
|
||||
|
||||
global _folder
|
||||
_folder = folder
|
||||
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder)
|
||||
else:
|
||||
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
|
||||
|
||||
def write_data(key, data, fs):
|
||||
""" appends data to previous data written under key """
|
||||
|
||||
global _state
|
||||
|
||||
# convert to numpy if torch.Tensor is given
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.detach().numpy()
|
||||
|
||||
if not key in _state:
|
||||
_state[key] = {
|
||||
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
|
||||
'fs' : fs,
|
||||
'dim' : tuple(data.shape),
|
||||
'dtype' : str(data.dtype)
|
||||
}
|
||||
|
||||
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
|
||||
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
|
||||
else:
|
||||
if _state[key]['fs'] != fs:
|
||||
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
|
||||
if _state[key]['dtype'] != str(data.dtype):
|
||||
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
|
||||
if _state[key]['dim'] != tuple(data.shape):
|
||||
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
|
||||
|
||||
_state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
|
||||
""" clean up """
|
||||
for key in _state.keys():
|
||||
_state[key]['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
|
||||
""" retrieves written data as numpy arrays """
|
||||
|
||||
|
||||
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
|
||||
|
||||
return_dict = dict()
|
||||
|
||||
for key in keys:
|
||||
with open(os.path.join(folder, key + '.yml'), 'r') as f:
|
||||
value = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
|
||||
data = np.frombuffer(f.read(), dtype=value['dtype'])
|
||||
|
||||
value['data'] = data.reshape((-1,) + value['dim'])
|
||||
|
||||
return_dict[key] = value
|
||||
|
||||
return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
|
||||
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
|
||||
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return (1,)
|
||||
|
||||
num_columns = int((pixel_count / target_ratio)**.5)
|
||||
|
||||
while (pixel_count % num_columns):
|
||||
num_columns -= 1
|
||||
|
||||
num_rows = pixel_count // num_columns
|
||||
|
||||
return (num_rows, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
|
||||
|
||||
# can happen if data is one dimensional
|
||||
if len(shape) == 0:
|
||||
shape = (1,)
|
||||
|
||||
# calculate pixel count
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return 'plot', (1, )
|
||||
|
||||
# stay with shape if already 2-dimensional
|
||||
if len(shape) == 2:
|
||||
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
|
||||
return 'image', shape
|
||||
|
||||
return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
|
||||
|
||||
# determine plot setup
|
||||
num_keys = len(data.keys())
|
||||
|
||||
num_rows = int((num_keys * 3/4) ** .5)
|
||||
|
||||
num_cols = (num_keys + num_rows - 1) // num_rows
|
||||
|
||||
fig, axs = plt.subplots(num_rows, num_cols)
|
||||
fig.set_size_inches(num_cols * 5, num_rows * 5)
|
||||
|
||||
display = dict()
|
||||
|
||||
fs_max = max([val['fs'] for val in data.values()])
|
||||
|
||||
num_samples = max([val['data'].shape[0] for val in data.values()])
|
||||
|
||||
keys = sorted(data.keys())
|
||||
|
||||
# inspect data
|
||||
for i, key in enumerate(keys):
|
||||
axs[i // num_cols, i % num_cols].title.set_text(key)
|
||||
|
||||
display[key] = dict()
|
||||
|
||||
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
|
||||
display[key]['down_factor'] = data[key]['fs'] / fs_max
|
||||
|
||||
start_index = max(start_index, half_signal_window_length)
|
||||
while stop_index < 0:
|
||||
stop_index += num_samples
|
||||
|
||||
stop_index = min(stop_index, num_samples - half_signal_window_length)
|
||||
|
||||
# actual plotting
|
||||
frames = []
|
||||
for index in range(start_index, stop_index):
|
||||
ims = []
|
||||
for i, key in enumerate(keys):
|
||||
feature_index = int(round(index * display[key]['down_factor']))
|
||||
|
||||
if display[key]['type'] == 'plot':
|
||||
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
|
||||
|
||||
elif display[key]['type'] == 'image':
|
||||
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
|
||||
|
||||
frames.append(ims)
|
||||
|
||||
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
|
||||
|
||||
if not filename.endswith('.mp4'):
|
||||
filename += '.mp4'
|
||||
|
||||
ani.save(filename)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .dual_fc import DualFC
|
||||
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
|
||||
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding
|
||||
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class DualFC(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(DualFC, self).__init__()
|
||||
|
||||
self.dense1 = nn.Linear(input_dim, output_dim)
|
||||
self.dense2 = nn.Linear(input_dim, output_dim)
|
||||
|
||||
self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" module implementing PCM embeddings for LPCNet """
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PCMEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=128, num_levels=256):
|
||||
super(PCMEmbedding, self).__init__()
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.num_levels = num_levels
|
||||
|
||||
self.embedding = nn.Embedding(self.num_levels, self.num_dim)
|
||||
|
||||
# initialize
|
||||
with torch.no_grad():
|
||||
num_rows, num_cols = self.num_levels, self.embed_dim
|
||||
a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
|
||||
for i in range(num_rows):
|
||||
a[i, :] += m.sqrt(12) * (i - num_rows / 2)
|
||||
self.embedding.weight[:, :] = 0.1 * a
|
||||
|
||||
def forward(self, x):
|
||||
return self.embeddint(x)
|
||||
|
||||
|
||||
class DifferentiablePCMEmbedding(PCMEmbedding):
|
||||
def __init__(self, embed_dim, num_levels=256):
|
||||
super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)
|
||||
|
||||
def forward(self, x):
|
||||
x_int = (x - torch.floor(x)).detach().long()
|
||||
x_frac = x - x_int
|
||||
x_next = torch.minimum(x_int + 1, self.num_levels)
|
||||
|
||||
embed_0 = self.embedding(x_int)
|
||||
embed_1 = self.embedding(x_next)
|
||||
|
||||
return (1 - x_frac) * embed_0 + x_frac * embed_1
|
||||
@@ -0,0 +1,497 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from re import sub
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
||||
|
||||
def get_subconditioner( method,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
**kwargs):
|
||||
|
||||
subconditioner_dict = {
|
||||
'additive' : AdditiveSubconditioner,
|
||||
'concatenative' : ConcatenativeSubconditioner,
|
||||
'modulative' : ModulativeSubconditioner
|
||||
}
|
||||
|
||||
return subconditioner_dict[method](number_of_subsamples,
|
||||
pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
|
||||
|
||||
|
||||
class Subconditioner(nn.Module):
|
||||
def __init__(self):
|
||||
""" upsampling by subconditioning
|
||||
|
||||
Upsamples a sequence of states conditioning on pcm signals and
|
||||
optionally a feature vector.
|
||||
"""
|
||||
super(Subconditioner, self).__init__()
|
||||
|
||||
def forward(self, states, signals, features=None):
|
||||
raise Exception("Base class should not be called")
|
||||
|
||||
def single_step(self, index, state, signals, features):
|
||||
raise Exception("Base class should not be called")
|
||||
|
||||
def get_output_dim(self, index):
|
||||
raise Exception("Base class should not be called")
|
||||
|
||||
|
||||
class AdditiveSubconditioner(Subconditioner):
|
||||
def __init__(self,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
**kwargs):
|
||||
""" subconditioning by addition """
|
||||
|
||||
super(AdditiveSubconditioner, self).__init__()
|
||||
|
||||
self.number_of_subsamples = number_of_subsamples
|
||||
self.pcm_embedding_size = pcm_embedding_size
|
||||
self.state_size = state_size
|
||||
self.pcm_levels = pcm_levels
|
||||
self.number_of_signals = number_of_signals
|
||||
|
||||
if self.pcm_embedding_size != self.state_size:
|
||||
raise ValueError('For additive subconditioning state and embedding '
|
||||
+ f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')
|
||||
|
||||
self.embeddings = [None]
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
|
||||
self.add_module('pcm_embedding_' + str(i), embedding)
|
||||
self.embeddings.append(embedding)
|
||||
|
||||
def forward(self, states, signals):
|
||||
""" creates list of subconditioned states
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
states : torch.tensor
|
||||
states of shape (batch, seq_length // s, state_size)
|
||||
signals : torch.tensor
|
||||
signals of shape (batch, seq_length, number_of_signals)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
c_states : list of torch.tensor
|
||||
list of s subconditioned states
|
||||
"""
|
||||
|
||||
s = self.number_of_subsamples
|
||||
|
||||
c_states = [states]
|
||||
new_states = states
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embed = self.embeddings[i](signals[:, i::s])
|
||||
# reduce signal dimension
|
||||
embed = torch.sum(embed, dim=2)
|
||||
|
||||
new_states = new_states + embed
|
||||
c_states.append(new_states)
|
||||
|
||||
return c_states
|
||||
|
||||
def single_step(self, index, state, signals):
|
||||
""" carry out single step for inference
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
index : int
|
||||
position in subconditioning batch
|
||||
|
||||
state : torch.tensor
|
||||
state to sub-condition
|
||||
|
||||
signals : torch.tensor
|
||||
signals for subconditioning, all but the last dimensions
|
||||
must match those of state
|
||||
|
||||
Returns:
|
||||
c_state : torch.tensor
|
||||
subconditioned state
|
||||
"""
|
||||
|
||||
if index == 0:
|
||||
c_state = state
|
||||
else:
|
||||
embed_signals = self.embeddings[index](signals)
|
||||
c = torch.sum(embed_signals, dim=-2)
|
||||
c_state = state + c
|
||||
|
||||
return c_state
|
||||
|
||||
def get_output_dim(self, index):
|
||||
return self.state_size
|
||||
|
||||
def get_average_flops_per_step(self):
|
||||
s = self.number_of_subsamples
|
||||
flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
|
||||
return flops
|
||||
|
||||
|
||||
class ConcatenativeSubconditioner(Subconditioner):
|
||||
def __init__(self,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
recurrent=True,
|
||||
**kwargs):
|
||||
""" subconditioning by concatenation """
|
||||
|
||||
super(ConcatenativeSubconditioner, self).__init__()
|
||||
|
||||
self.number_of_subsamples = number_of_subsamples
|
||||
self.pcm_embedding_size = pcm_embedding_size
|
||||
self.state_size = state_size
|
||||
self.pcm_levels = pcm_levels
|
||||
self.number_of_signals = number_of_signals
|
||||
self.recurrent = recurrent
|
||||
|
||||
self.embeddings = []
|
||||
start_index = 0
|
||||
if self.recurrent:
|
||||
start_index = 1
|
||||
self.embeddings.append(None)
|
||||
|
||||
for i in range(start_index, self.number_of_subsamples):
|
||||
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
|
||||
self.add_module('pcm_embedding_' + str(i), embedding)
|
||||
self.embeddings.append(embedding)
|
||||
|
||||
def forward(self, states, signals):
|
||||
""" creates list of subconditioned states
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
states : torch.tensor
|
||||
states of shape (batch, seq_length // s, state_size)
|
||||
signals : torch.tensor
|
||||
signals of shape (batch, seq_length, number_of_signals)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
c_states : list of torch.tensor
|
||||
list of s subconditioned states
|
||||
"""
|
||||
s = self.number_of_subsamples
|
||||
|
||||
if self.recurrent:
|
||||
c_states = [states]
|
||||
start = 1
|
||||
else:
|
||||
c_states = []
|
||||
start = 0
|
||||
|
||||
new_states = states
|
||||
for i in range(start, self.number_of_subsamples):
|
||||
embed = self.embeddings[i](signals[:, i::s])
|
||||
# reduce signal dimension
|
||||
embed = torch.flatten(embed, -2)
|
||||
|
||||
if self.recurrent:
|
||||
new_states = torch.cat((new_states, embed), dim=-1)
|
||||
else:
|
||||
new_states = torch.cat((states, embed), dim=-1)
|
||||
|
||||
c_states.append(new_states)
|
||||
|
||||
return c_states
|
||||
|
||||
def single_step(self, index, state, signals):
|
||||
""" carry out single step for inference
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
index : int
|
||||
position in subconditioning batch
|
||||
|
||||
state : torch.tensor
|
||||
state to sub-condition
|
||||
|
||||
signals : torch.tensor
|
||||
signals for subconditioning, all but the last dimensions
|
||||
must match those of state
|
||||
|
||||
Returns:
|
||||
c_state : torch.tensor
|
||||
subconditioned state
|
||||
"""
|
||||
|
||||
if index == 0 and self.recurrent:
|
||||
c_state = state
|
||||
else:
|
||||
embed_signals = self.embeddings[index](signals)
|
||||
c = torch.flatten(embed_signals, -2)
|
||||
if not self.recurrent and index > 0:
|
||||
# overwrite previous conditioning vector
|
||||
c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
|
||||
else:
|
||||
c_state = torch.cat((state, c), dim=-1)
|
||||
return c_state
|
||||
|
||||
return c_state
|
||||
|
||||
def get_average_flops_per_step(self):
|
||||
return 0
|
||||
|
||||
def get_output_dim(self, index):
|
||||
if self.recurrent:
|
||||
return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
|
||||
else:
|
||||
return self.state_size + self.pcm_embedding_size * self.number_of_signals
|
||||
|
||||
class ModulativeSubconditioner(Subconditioner):
|
||||
def __init__(self,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
state_recurrent=False,
|
||||
**kwargs):
|
||||
""" subconditioning by modulation """
|
||||
|
||||
super(ModulativeSubconditioner, self).__init__()
|
||||
|
||||
self.number_of_subsamples = number_of_subsamples
|
||||
self.pcm_embedding_size = pcm_embedding_size
|
||||
self.state_size = state_size
|
||||
self.pcm_levels = pcm_levels
|
||||
self.number_of_signals = number_of_signals
|
||||
self.state_recurrent = state_recurrent
|
||||
|
||||
self.hidden_size = self.pcm_embedding_size * self.number_of_signals
|
||||
|
||||
if self.state_recurrent:
|
||||
self.hidden_size += self.pcm_embedding_size
|
||||
self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)
|
||||
|
||||
self.embeddings = [None]
|
||||
self.alphas = [None]
|
||||
self.betas = [None]
|
||||
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
|
||||
self.add_module('pcm_embedding_' + str(i), embedding)
|
||||
self.embeddings.append(embedding)
|
||||
|
||||
self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
|
||||
self.add_module('alpha_dense_' + str(i), self.alphas[-1])
|
||||
|
||||
self.betas.append(nn.Linear(self.hidden_size, self.state_size))
|
||||
self.add_module('beta_dense_' + str(i), self.betas[-1])
|
||||
|
||||
|
||||
|
||||
def forward(self, states, signals):
|
||||
""" creates list of subconditioned states
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
states : torch.tensor
|
||||
states of shape (batch, seq_length // s, state_size)
|
||||
signals : torch.tensor
|
||||
signals of shape (batch, seq_length, number_of_signals)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
c_states : list of torch.tensor
|
||||
list of s subconditioned states
|
||||
"""
|
||||
s = self.number_of_subsamples
|
||||
|
||||
c_states = [states]
|
||||
new_states = states
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embed = self.embeddings[i](signals[:, i::s])
|
||||
# reduce signal dimension
|
||||
embed = torch.flatten(embed, -2)
|
||||
|
||||
if self.state_recurrent:
|
||||
comp_states = self.state_transform(new_states)
|
||||
embed = torch.cat((embed, comp_states), dim=-1)
|
||||
|
||||
alpha = torch.tanh(self.alphas[i](embed))
|
||||
beta = torch.tanh(self.betas[i](embed))
|
||||
|
||||
# new state obtained by modulating previous state
|
||||
new_states = torch.tanh((1 + alpha) * new_states + beta)
|
||||
|
||||
c_states.append(new_states)
|
||||
|
||||
return c_states
|
||||
|
||||
def single_step(self, index, state, signals):
|
||||
""" carry out single step for inference
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
index : int
|
||||
position in subconditioning batch
|
||||
|
||||
state : torch.tensor
|
||||
state to sub-condition
|
||||
|
||||
signals : torch.tensor
|
||||
signals for subconditioning, all but the last dimensions
|
||||
must match those of state
|
||||
|
||||
Returns:
|
||||
c_state : torch.tensor
|
||||
subconditioned state
|
||||
"""
|
||||
|
||||
if index == 0:
|
||||
c_state = state
|
||||
else:
|
||||
embed_signals = self.embeddings[index](signals)
|
||||
c = torch.flatten(embed_signals, -2)
|
||||
if self.state_recurrent:
|
||||
r_state = self.state_transform(state)
|
||||
c = torch.cat((c, r_state), dim=-1)
|
||||
alpha = torch.tanh(self.alphas[index](c))
|
||||
beta = torch.tanh(self.betas[index](c))
|
||||
c_state = torch.tanh((1 + alpha) * state + beta)
|
||||
return c_state
|
||||
|
||||
return c_state
|
||||
|
||||
def get_output_dim(self, index):
|
||||
return self.state_size
|
||||
|
||||
def get_average_flops_per_step(self):
|
||||
s = self.number_of_subsamples
|
||||
|
||||
# estimate activation by 10 flops
|
||||
# c_state = torch.tanh((1 + alpha) * state + beta)
|
||||
flops = 13 * self.state_size
|
||||
|
||||
# hidden size
|
||||
hidden_size = self.number_of_signals * self.pcm_embedding_size
|
||||
if self.state_recurrent:
|
||||
hidden_size += self.pcm_embedding_size
|
||||
|
||||
# counting 2 * A * B flops for Linear(A, B)
|
||||
# alpha = torch.tanh(self.alphas[index](c))
|
||||
# beta = torch.tanh(self.betas[index](c))
|
||||
flops += 4 * hidden_size * self.state_size + 20 * self.state_size
|
||||
|
||||
# r_state = self.state_transform(state)
|
||||
if self.state_recurrent:
|
||||
flops += 2 * self.state_size * self.pcm_embedding_size
|
||||
|
||||
# average over steps
|
||||
flops *= (s - 1) / s
|
||||
|
||||
return flops
|
||||
|
||||
class ComparitiveSubconditioner(Subconditioner):
|
||||
def __init__(self,
|
||||
number_of_subsamples,
|
||||
pcm_embedding_size,
|
||||
state_size,
|
||||
pcm_levels,
|
||||
number_of_signals,
|
||||
error_index=-1,
|
||||
apply_gate=True,
|
||||
normalize=False):
|
||||
""" subconditioning by comparison """
|
||||
|
||||
super(ComparitiveSubconditioner, self).__init__()
|
||||
|
||||
self.comparison_size = self.pcm_embedding_size
|
||||
self.error_position = error_index
|
||||
self.apply_gate = apply_gate
|
||||
self.normalize = normalize
|
||||
|
||||
self.state_transform = nn.Linear(self.state_size, self.comparison_size)
|
||||
|
||||
self.alpha_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
|
||||
self.beta_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
|
||||
|
||||
if self.apply_gate:
|
||||
self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)
|
||||
|
||||
# embeddings and state transforms
|
||||
self.embeddings = [None]
|
||||
self.alpha_denses = [None]
|
||||
self.beta_denses = [None]
|
||||
self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
|
||||
self.add_module('state_transform_0', self.state_transforms[0])
|
||||
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
|
||||
self.add_module('pcm_embedding_' + str(i), embedding)
|
||||
self.embeddings.append(embedding)
|
||||
|
||||
state_transform = nn.Linear(self.state_size, self.comparison_size)
|
||||
self.add_module('state_transform_' + str(i), state_transform)
|
||||
self.state_transforms.append(state_transform)
|
||||
|
||||
self.alpha_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
|
||||
self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
|
||||
|
||||
self.beta_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
|
||||
self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
|
||||
|
||||
def forward(self, states, signals):
|
||||
s = self.number_of_subsamples
|
||||
|
||||
c_states = [states]
|
||||
new_states = states
|
||||
for i in range(1, self.number_of_subsamples):
|
||||
embed = self.embeddings[i](signals[:, i::s])
|
||||
# reduce signal dimension
|
||||
embed = torch.flatten(embed, -2)
|
||||
|
||||
comp_states = self.state_transforms[i](new_states)
|
||||
|
||||
alpha = torch.tanh(self.alpha_dense(embed))
|
||||
beta = torch.tanh(self.beta_dense(embed))
|
||||
|
||||
# new state obtained by modulating previous state
|
||||
new_states = torch.tanh((1 + alpha) * comp_states + beta)
|
||||
|
||||
c_states.append(new_states)
|
||||
|
||||
return c_states
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def find(a, v):
|
||||
try:
|
||||
idx = a.index(v)
|
||||
except:
|
||||
idx = -1
|
||||
return idx
|
||||
|
||||
def interleave_tensors(tensors, dim=-2):
|
||||
""" interleave list of tensors along sequence dimension """
|
||||
|
||||
x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim)
|
||||
x = torch.flatten(x, dim - 1, dim)
|
||||
|
||||
return x
|
||||
|
||||
def _interleave(x, pcm_levels=256):
|
||||
|
||||
repeats = pcm_levels // (2*x.size(-1))
|
||||
x = x.unsqueeze(-1)
|
||||
p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
|
||||
|
||||
return p
|
||||
|
||||
def get_pdf_from_tree(x):
|
||||
pcm_levels = x.size(-1)
|
||||
|
||||
p = _interleave(x[..., 1:2])
|
||||
n = 4
|
||||
while n <= pcm_levels:
|
||||
p = p * _interleave(x[..., n//2:n])
|
||||
n *= 2
|
||||
|
||||
return p
|
||||
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
|
||||
def clip_to_int16(x):
|
||||
int_min = -2**15
|
||||
int_max = 2**15 - 1
|
||||
x_clipped = max(int_min, min(x, int_max))
|
||||
return x_clipped
|
||||
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sample_excitation(probs, pitch_corr):
|
||||
|
||||
norm = lambda x : x / (x.sum() + 1e-18)
|
||||
|
||||
# lowering the temperature
|
||||
probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
|
||||
# cut-off tails
|
||||
probs = norm(torch.maximum(probs - 0.002 , torch.FloatTensor([0])))
|
||||
# sample
|
||||
exc = torch.multinomial(probs.squeeze(), 1)
|
||||
|
||||
return exc
|
||||
@@ -0,0 +1,2 @@
|
||||
from .gru_sparsifier import GRUSparsifier
|
||||
from .common import sparsify_matrix, calculate_gru_flops_per_step
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
|
||||
""" sparsifies matrix with specified block size
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
matrix : torch.tensor
|
||||
matrix to sparsify
|
||||
density : int
|
||||
target density
|
||||
block_size : [int, int]
|
||||
block size dimensions
|
||||
keep_diagonal : bool
|
||||
If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
|
||||
"""
|
||||
|
||||
m, n = matrix.shape
|
||||
m1, n1 = block_size
|
||||
|
||||
if m % m1 or n % n1:
|
||||
raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
|
||||
|
||||
# extract diagonal if keep_diagonal = True
|
||||
if keep_diagonal:
|
||||
if m != n:
|
||||
raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
|
||||
|
||||
to_spare = torch.diag(torch.diag(matrix))
|
||||
matrix = matrix - to_spare
|
||||
else:
|
||||
to_spare = torch.zeros_like(matrix)
|
||||
|
||||
# calculate energy in sub-blocks
|
||||
x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
|
||||
x = x ** 2
|
||||
block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
|
||||
|
||||
number_of_blocks = (m * n) // (m1 * n1)
|
||||
number_of_survivors = round(number_of_blocks * density)
|
||||
|
||||
# masking threshold
|
||||
if number_of_survivors == 0:
|
||||
threshold = 0
|
||||
else:
|
||||
threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
|
||||
|
||||
# create mask
|
||||
mask = torch.ones_like(block_energies)
|
||||
mask[block_energies < threshold] = 0
|
||||
mask = torch.repeat_interleave(mask, m1, dim=0)
|
||||
mask = torch.repeat_interleave(mask, n1, dim=1)
|
||||
|
||||
# perform masking
|
||||
masked_matrix = mask * matrix + to_spare
|
||||
|
||||
if return_mask:
|
||||
return masked_matrix, mask
|
||||
else:
|
||||
return masked_matrix
|
||||
|
||||
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
|
||||
input_size = gru.input_size
|
||||
hidden_size = gru.hidden_size
|
||||
flops = 0
|
||||
|
||||
input_density = (
|
||||
sparsification_dict.get('W_ir', [1])[0]
|
||||
+ sparsification_dict.get('W_in', [1])[0]
|
||||
+ sparsification_dict.get('W_iz', [1])[0]
|
||||
) / 3
|
||||
|
||||
recurrent_density = (
|
||||
sparsification_dict.get('W_hr', [1])[0]
|
||||
+ sparsification_dict.get('W_hn', [1])[0]
|
||||
+ sparsification_dict.get('W_hz', [1])[0]
|
||||
) / 3
|
||||
|
||||
# input matrix vector multiplications
|
||||
if not drop_input:
|
||||
flops += 2 * 3 * input_size * hidden_size * input_density
|
||||
|
||||
# recurrent matrix vector multiplications
|
||||
flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
|
||||
|
||||
# biases
|
||||
flops += 6 * hidden_size
|
||||
|
||||
# activations estimated by 10 flops per activation
|
||||
flops += 30 * hidden_size
|
||||
|
||||
return flops
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .common import sparsify_matrix
|
||||
|
||||
|
||||
class GRUSparsifier:
|
||||
def __init__(self, task_list, start, stop, interval, exponent=3):
|
||||
""" Sparsifier for torch.nn.GRUs
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
task_list : list
|
||||
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
|
||||
of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
|
||||
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
|
||||
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
|
||||
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
|
||||
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
|
||||
should be kept.
|
||||
|
||||
start : int
|
||||
training step after which sparsification will be started.
|
||||
|
||||
stop : int
|
||||
training step after which sparsification will be completed.
|
||||
|
||||
interval : int
|
||||
sparsification interval for steps between start and stop. After stop sparsification will be
|
||||
carried out after every call to GRUSparsifier.step()
|
||||
|
||||
exponent : float
|
||||
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
|
||||
with density (alpha + target_density * (1 * alpha)), where
|
||||
alpha = ((stop - i) / (start - stop)) ** exponent
|
||||
|
||||
Example:
|
||||
--------
|
||||
>>> import torch
|
||||
>>> gru = torch.nn.GRU(10, 20)
|
||||
>>> sparsify_dict = {
|
||||
... 'W_ir' : (0.5, [2, 2], False),
|
||||
... 'W_iz' : (0.6, [2, 2], False),
|
||||
... 'W_in' : (0.7, [2, 2], False),
|
||||
... 'W_hr' : (0.1, [4, 4], True),
|
||||
... 'W_hz' : (0.2, [4, 4], True),
|
||||
... 'W_hn' : (0.3, [4, 4], True),
|
||||
... }
|
||||
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
|
||||
>>> for i in range(100):
|
||||
... sparsifier.step()
|
||||
"""
|
||||
# just copying parameters...
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.interval = interval
|
||||
self.exponent = exponent
|
||||
self.task_list = task_list
|
||||
|
||||
# ... and setting counter to 0
|
||||
self.step_counter = 0
|
||||
|
||||
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
|
||||
|
||||
def step(self, verbose=False):
|
||||
""" carries out sparsification step
|
||||
|
||||
Call this function after optimizer.step in your
|
||||
training loop.
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
verbose : bool
|
||||
if true, densities are printed out
|
||||
|
||||
Returns:
|
||||
--------
|
||||
None
|
||||
|
||||
"""
|
||||
# compute current interpolation factor
|
||||
self.step_counter += 1
|
||||
|
||||
if self.step_counter < self.start:
|
||||
return
|
||||
elif self.step_counter < self.stop:
|
||||
# update only every self.interval-th interval
|
||||
if self.step_counter % self.interval:
|
||||
return
|
||||
|
||||
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
|
||||
else:
|
||||
alpha = 0
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
for gru, params in self.task_list:
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
# input weights
|
||||
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
|
||||
if key in params:
|
||||
density = alpha + (1 - alpha) * params[key][0]
|
||||
if verbose:
|
||||
print(f"[{self.step_counter}]: {key} density: {density}")
|
||||
|
||||
gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
|
||||
gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
|
||||
density, # density
|
||||
params[key][1], # block_size
|
||||
params[key][2], # keep_diagonal (might want to set this to False)
|
||||
return_mask=True
|
||||
)
|
||||
|
||||
if type(self.last_masks[key]) != type(None):
|
||||
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
|
||||
print(f"sparsification mask {key} changed for gru {gru}")
|
||||
|
||||
self.last_masks[key] = new_mask
|
||||
|
||||
# recurrent weights
|
||||
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
|
||||
if key in params:
|
||||
density = alpha + (1 - alpha) * params[key][0]
|
||||
if verbose:
|
||||
print(f"[{self.step_counter}]: {key} density: {density}")
|
||||
gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
|
||||
gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
|
||||
density,
|
||||
params[key][1], # block_size
|
||||
params[key][2], # keep_diagonal (might want to set this to False)
|
||||
return_mask=True
|
||||
)
|
||||
|
||||
if type(self.last_masks[key]) != type(None):
|
||||
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
|
||||
print(f"sparsification mask {key} changed for gru {gru}")
|
||||
|
||||
self.last_masks[key] = new_mask
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing sparsifier")
|
||||
|
||||
gru = torch.nn.GRU(10, 20)
|
||||
sparsify_dict = {
|
||||
'W_ir' : (0.5, [2, 2], False),
|
||||
'W_iz' : (0.6, [2, 2], False),
|
||||
'W_in' : (0.7, [2, 2], False),
|
||||
'W_hr' : (0.1, [4, 4], True),
|
||||
'W_hz' : (0.2, [4, 4], True),
|
||||
'W_hn' : (0.3, [4, 4], True),
|
||||
}
|
||||
|
||||
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
|
||||
|
||||
for i in range(100):
|
||||
sparsifier.step(verbose=True)
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
from models import multi_rate_lpcnet
|
||||
import copy
|
||||
|
||||
setup_dict = dict()
|
||||
|
||||
dataset_template_v2 = {
|
||||
'version' : 2,
|
||||
'feature_file' : 'features.f32',
|
||||
'signal_file' : 'data.s16',
|
||||
'frame_length' : 160,
|
||||
'feature_frame_length' : 36,
|
||||
'signal_frame_length' : 2,
|
||||
'feature_dtype' : 'float32',
|
||||
'signal_dtype' : 'int16',
|
||||
'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
|
||||
'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
|
||||
}
|
||||
|
||||
dataset_template_v1 = {
|
||||
'version' : 1,
|
||||
'feature_file' : 'features.f32',
|
||||
'signal_file' : 'data.u8',
|
||||
'frame_length' : 160,
|
||||
'feature_frame_length' : 55,
|
||||
'signal_frame_length' : 4,
|
||||
'feature_dtype' : 'float32',
|
||||
'signal_dtype' : 'uint8',
|
||||
'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
|
||||
'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
|
||||
}
|
||||
|
||||
# lpcnet
|
||||
|
||||
lpcnet_config = {
|
||||
'frame_size' : 160,
|
||||
'gru_a_units' : 384,
|
||||
'gru_b_units' : 64,
|
||||
'feature_conditioning_dim' : 128,
|
||||
'feature_conv_kernel_size' : 3,
|
||||
'period_levels' : 257,
|
||||
'period_embedding_dim' : 64,
|
||||
'signal_embedding_dim' : 128,
|
||||
'signal_levels' : 256,
|
||||
'feature_dimension' : 19,
|
||||
'output_levels' : 256,
|
||||
'lpc_gamma' : 0.9,
|
||||
'features' : ['cepstrum', 'periods', 'pitch_corr'],
|
||||
'signals' : ['last_signal', 'prediction', 'last_error'],
|
||||
'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
|
||||
'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
|
||||
'target' : 'error',
|
||||
'feature_history' : 2,
|
||||
'feature_lookahead' : 2,
|
||||
'sparsification' : {
|
||||
'gru_a' : {
|
||||
'start' : 10000,
|
||||
'stop' : 30000,
|
||||
'interval' : 100,
|
||||
'exponent' : 3,
|
||||
'params' : {
|
||||
'W_hr' : (0.05, [4, 8], True),
|
||||
'W_hz' : (0.05, [4, 8], True),
|
||||
'W_hn' : (0.2, [4, 8], True)
|
||||
},
|
||||
},
|
||||
'gru_b' : {
|
||||
'start' : 10000,
|
||||
'stop' : 30000,
|
||||
'interval' : 100,
|
||||
'exponent' : 3,
|
||||
'params' : {
|
||||
'W_ir' : (0.5, [4, 8], False),
|
||||
'W_iz' : (0.5, [4, 8], False),
|
||||
'W_in' : (0.5, [4, 8], False)
|
||||
},
|
||||
}
|
||||
},
|
||||
'add_reference_phase' : False,
|
||||
'reference_phase_dim' : 0
|
||||
}
|
||||
|
||||
|
||||
|
||||
# multi rate
|
||||
subconditioning = {
|
||||
'subconditioning_a' : {
|
||||
'number_of_subsamples' : 2,
|
||||
'method' : 'modulative',
|
||||
'signals' : ['last_signal', 'prediction', 'last_error'],
|
||||
'pcm_embedding_size' : 64,
|
||||
'kwargs' : dict()
|
||||
|
||||
},
|
||||
'subconditioning_b' : {
|
||||
'number_of_subsamples' : 2,
|
||||
'method' : 'modulative',
|
||||
'signals' : ['last_signal', 'prediction', 'last_error'],
|
||||
'pcm_embedding_size' : 64,
|
||||
'kwargs' : dict()
|
||||
}
|
||||
}
|
||||
|
||||
multi_rate_lpcnet_config = lpcnet_config.copy()
|
||||
multi_rate_lpcnet_config['subconditioning'] = subconditioning
|
||||
|
||||
training_default = {
|
||||
'batch_size' : 256,
|
||||
'epochs' : 20,
|
||||
'lr' : 1e-3,
|
||||
'lr_decay_factor' : 2.5e-5,
|
||||
'adam_betas' : [0.9, 0.99],
|
||||
'frames_per_sample' : 15
|
||||
}
|
||||
|
||||
lpcnet_setup = {
|
||||
'dataset' : '/local/datasets/lpcnet_training',
|
||||
'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
|
||||
'training' : training_default
|
||||
}
|
||||
|
||||
multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
|
||||
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
|
||||
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'
|
||||
|
||||
setup_dict = {
|
||||
'lpcnet' : lpcnet_setup,
|
||||
'multi_rate' : multi_rate_lpcnet_setup
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def ulaw2lin(u):
|
||||
scale_1 = 32768.0 / 255.0
|
||||
u = u - 128
|
||||
s = torch.sign(u)
|
||||
u = torch.abs(u)
|
||||
return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)
|
||||
|
||||
|
||||
def lin2ulawq(x):
|
||||
scale = 255.0 / 32768.0
|
||||
s = torch.sign(x)
|
||||
x = torch.abs(x)
|
||||
u = s * (128 * torch.log(1 + scale * x) / m.log(256))
|
||||
u = torch.clip(128 + torch.round(u), 0, 255)
|
||||
return u
|
||||
|
||||
def lin2ulaw(x):
|
||||
scale = 255.0 / 32768.0
|
||||
s = torch.sign(x)
|
||||
x = torch.abs(x)
|
||||
u = s * (128 * torch.log(1 + scale * x) / torch.log(256))
|
||||
u = torch.clip(128 + u, 0, 255)
|
||||
return u
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import wave
|
||||
|
||||
def wavwrite16(filename, x, fs):
|
||||
""" writes x as int16 to file with name filename
|
||||
|
||||
If x.dtype is int16 x is written as is. Otherwise,
|
||||
it is scaled by 2**15 - 1 and converted to int16.
|
||||
"""
|
||||
if x.dtype != 'int16':
|
||||
x = ((2**15 - 1) * x).astype('int16')
|
||||
|
||||
with wave.open(filename, 'wb') as f:
|
||||
f.setparams((1, 2, fs, len(x), 'NONE', ""))
|
||||
f.writeframes(x.tobytes())
|
||||
Reference in New Issue
Block a user