add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from . import sparsification
from . import data
from . import pcm
from . import sample

View File

@@ -0,0 +1,141 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import torch
import numpy as np
def load_features(feature_file, version=2):
    """ loads a feature file and splits it into features, periods and lpcs

        The file is read as a flat sequence of float32 values and reshaped
        into frames whose length depends on the feature version.

        Parameters:
        -----------
        feature_file : str
            path to the binary feature file
        version : int
            feature layout version; 2 uses 36 values per frame, 1 uses 55

        Returns:
        --------
        dict
            'features' : cepstrum columns followed by the pitch correlation column
            'periods'  : (0.1 + 50 * period column + 100) converted with .long()
            'lpcs'     : lpc columns

        Raises:
        -------
        ValueError
            if version is neither 1 nor 2
    """
    # (start, stop) column ranges of every component inside one frame
    if version == 2:
        frame_length = 36
        cepstrum_cols, period_cols, pitch_corr_cols, lpc_cols = (0, 18), (18, 19), (19, 20), (20, 36)
    elif version == 1:
        frame_length = 55
        cepstrum_cols, period_cols, pitch_corr_cols, lpc_cols = (0, 18), (36, 37), (37, 38), (39, 55)
    else:
        raise ValueError(f'unknown feature version: {version}')

    flat_values = np.fromfile(feature_file, dtype='float32')
    frames = torch.from_numpy(flat_values).reshape((-1, frame_length))

    def columns(col_range):
        first, last = col_range
        return frames[:, first:last]

    features = torch.cat([columns(cepstrum_cols), columns(pitch_corr_cols)], dim=1)
    periods = (0.1 + 50 * columns(period_cols) + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : columns(lpc_cols)}
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """ writes a pre-emphasized copy of a signal and a new data file derived from it

        The int16 signal is filtered frame by frame (160 samples per frame) with
        the taps [1, -preemph_factor] and stored next to the input as
        '<signal name>_preemph.raw'. A new int16 file with the same number of
        entries as the reference data is then created, zero-filled, and every
        pre-emphasized sample from `offset` onwards is written twice, to the
        odd position 1 + 2k and to the even position 2 + 2k.

        Parameters:
        -----------
        signal_path : str
            path to raw int16 signal; its length must be a multiple of 160
        reference_data_path : str
            path to raw int16 file that only determines the size of the output
        new_data_path : str
            path of the data file to create
        offset : int
            number of leading pre-emphasized samples to skip
        preemph_factor : float
            pre-emphasis coefficient
    """
    frame_size = 160

    ref_data = np.memmap(reference_data_path, dtype=np.int16)
    signal = np.memmap(signal_path, dtype=np.int16)

    preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    assert len(signal) % frame_size == 0
    num_frames = len(signal) // frame_size

    taps = [1, -preemph_factor]
    # last sample of the previous frame; zero before the first frame
    mem = np.zeros(1)
    for fr in range(num_frames):
        begin = fr * frame_size
        end = begin + frame_size
        extended_frame = np.concatenate((mem, signal[begin:end]))
        signal_preemph[begin:end] = np.convolve(extended_frame, taps, mode='valid')
        mem = signal[end - 1 : end]

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
    new_data[:] = 0

    tail = signal_preemph[offset:]
    N = len(signal) - offset
    new_data[1 : 2*N + 1 : 2] = tail
    new_data[2 : 2*N + 2 : 2] = tail
def parse_warpq_scores(output_file):
    """ extracts warpq scores from output file

        Every line that starts with the score marker contributes one float,
        taken from the text following the last marker on that line.
    """
    marker = "WARP-Q score:"
    scores = []
    with open(output_file, "r") as f:
        for line in f:
            if line.startswith(marker):
                scores.append(float(line.split(marker)[-1]))
    return scores
def parse_stats_file(file):
    """ reads the three statistics stored on the first three lines of a stats file

        Each value is the float found after the last ':' of its line.

        Returns:
        --------
        tuple
            (mean, bt_mean, top_mean) taken from lines 0, 1 and 2
    """
    with open(file, "r") as f:
        lines = f.readlines()

    # lines are indexed one after another so a short file raises IndexError
    mean, bt_mean, top_mean = tuple(float(lines[i].split(":")[-1]) for i in range(3))
    return mean, bt_mean, top_mean
def collect_test_stats(test_folder):
    """ collects statistics for all discovered metrics from test folder

        Every entry of test_folder whose name starts with 'stats_' is parsed
        with parse_stats_file. The metric name is the file name with the
        'stats_' prefix and the last four characters (the '.txt' suffix)
        removed. Unknown metric names only trigger a warning.

        Returns:
        --------
        dict
            maps metric name to [mean, bt_mean, top_mean]
    """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
    prefix, suffix = "stats_", ".txt"

    results = {}
    for name in os.listdir(test_folder):
        if not name.startswith(prefix):
            continue

        metric = name[len(prefix) : -len(suffix)]
        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, name))
        results[metric] = [mean, bt_mean, top_mean]

    return results

View File

@@ -0,0 +1,234 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# module-level bookkeeping used by init, write_data and close
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
# folder in which write_data creates its .bin and .yml files; replaced by init()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
    """ computes reset, update and new gate of the first GRU layer for one step

        Parameters:
        -----------
        gru : object
            provides hidden_size, weight_ih_l0, weight_hh_l0, bias_ih_l0 and
            bias_hh_l0 (presumably a torch.nn.GRU -- confirm against caller)
        input : torch.tensor
            input of the step; squeezed before multiplication with weight_ih_l0
        state : torch.tensor
            previous state; squeezed before multiplication with weight_hh_l0

        Returns:
        --------
        dict
            tensors under the keys 'reset_gate', 'update_gate' and 'new_gate'
    """
    size = gru.hidden_size

    # contributions of input and state to all three gates at once
    from_input = torch.matmul(gru.weight_ih_l0, input.squeeze())
    from_state = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def gate_terms(gate_index):
        rows = slice(gate_index * size, (gate_index + 1) * size)
        return from_input[rows], gru.bias_ih_l0[rows], from_state[rows], gru.bias_hh_l0[rows]

    # gates are stored in the order reset, update, new
    x_r, bx_r, h_r, bh_r = gate_terms(0)
    reset_gate = torch.sigmoid(x_r + bx_r + h_r + bh_r)

    x_z, bx_z, h_z, bh_z = gate_terms(1)
    update_gate = torch.sigmoid(x_z + bx_z + h_z + bh_z)

    # the reset gate scales only the recurrent part of the new gate
    x_n, bx_n, h_n, bh_n = gate_terms(2)
    new_gate = torch.tanh(x_n + bx_n + reset_gate * (h_n + bh_n))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data

        Remembers folder as the module-wide output location and creates it if
        it does not exist yet; an existing folder only triggers a warning.
    """
    global _folder
    _folder = folder

    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
        return

    os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key

        On the first call for a key, '<key>.bin' is opened in the current
        endoscopy folder and '<key>.yml' is written with fs, dim and dtype.
        Later calls must use the same fs, dtype and shape.

        Parameters:
        -----------
        key : str
            name of the data stream; also the base name of the created files
        data : torch.Tensor or numpy.ndarray
            one chunk of data; tensors are detached and converted to numpy
        fs : number
            sampling rate stored with the stream

        Raises:
        -------
        ValueError
            if fs, dtype or shape differ from the first call for this key
    """
    global _state

    # convert to numpy if torch.Tensor is given; .cpu() is a no-op for tensors
    # that already live on the CPU and is required for tensors on other devices
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()

    dim = tuple(data.shape)
    dtype = str(data.dtype)

    if key not in _state:
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : dim,
            'dtype' : dtype
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : dim, 'dtype' : dtype.split('.')[-1]}))
    else:
        entry = _state[key]
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {entry['fs']} vs. {fs}")
        if entry['dtype'] != dtype:
            raise ValueError(f"dtype changed for key {key}: {entry['dtype']} vs. {dtype}")
        if entry['dim'] != dim:
            raise ValueError(f"dim changed for key {key}: {entry['dim']} vs. {dim}")

    _state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up

        Closes the binary file of every key written so far. The folder
        argument is accepted but not used.
    """
    for entry in _state.values():
        entry['fid'].close()
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

        For every '<key>.yml' in folder the metadata is loaded and the matching
        '<key>.bin' is read with the stored dtype and reshaped to (-1,) + dim.

        Returns:
        --------
        dict
            maps key to the metadata dict extended by a 'data' entry
    """
    collected = {}

    for name in os.listdir(folder):
        if not name.endswith('.yml'):
            continue
        key = name[:-4]

        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            raw = f.read()

        # first axis enumerates the chunks that were appended one by one
        value['data'] = np.frombuffer(raw, dtype=value['dtype']).reshape((-1,) + value['dim'])
        collected[key] = value

    return collected
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

        Parameters:
        -----------
        shape : sequence of int
            original shape; only the total number of elements matters
        target_ratio : number
            desired ratio rows / columns

        Returns:
        --------
        tuple
            (1,) if there is a single element, otherwise (num_rows, num_columns)
            where num_columns is the largest divisor of the element count that
            does not exceed sqrt(element count / target_ratio)
    """
    pixel_count = 1
    for extent in shape:
        pixel_count *= extent

    if pixel_count == 1:
        return (1,)

    # start at one column at least: a large target ratio would otherwise give
    # zero columns and a division by zero in the divisor search below
    num_columns = max(1, int((pixel_count / target_ratio)**.5))

    # walk down to the next divisor; terminates at 1 at the latest
    while pixel_count % num_columns:
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)
def get_type_and_shape(shape):
    """ decides how data of the given shape is displayed

        Parameters:
        -----------
        shape : tuple of int
            shape of a single data item

        Returns:
        --------
        tuple
            ('plot', (1,)) for a single value, otherwise ('image', shape2d)
            where shape2d is the original shape if it is two-dimensional with
            neither dimension holding all elements, and the result of
            get_best_reshape otherwise
    """
    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    pixel_count = 1
    for extent in shape:
        pixel_count *= extent

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional, unless one dimension holds all
    # elements (shape (N, 1) or (1, N)). The original joined the two tests with
    # `or`, which is always true at this point, so such shapes were never reshaped.
    if len(shape) == 2 and shape[0] != pixel_count and shape[1] != pixel_count:
        return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders all data streams into one animation and saves it as mp4

        Parameters:
        -----------
        data : dict
            maps key to a dict with entries 'fs', 'dim' and 'data', where
            data[key]['data'] is indexed by sample along its first axis
        filename : str
            output file; '.mp4' is appended if missing
        start_index : int
            first animated sample; raised to half_signal_window_length if smaller
        stop_index : int
            end of the animated range (exclusive); negative values count from the end
        interval : int
            passed to ArtistAnimation as interval
        half_signal_window_length : int
            number of samples shown on each side of the current index in plots
    """
    # determine plot setup
    keys = sorted(data.keys())
    num_keys = len(keys)
    # at least one row: a single key would otherwise give zero rows
    num_rows = max(1, int((num_keys * 3/4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows

    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # which the [row, column] indexing below relies on
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])
    num_samples = max([val['data'].shape[0] for val in data.values()])

    # inspect data
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            ax = axs[i // num_cols, i % num_cols]
            # NOTE(review): the original computed a rate-adjusted index
            # int(round(index * down_factor)) here but never used it; every
            # stream is indexed with the common sample index. Confirm whether
            # streams with a lower fs should use the adjusted index instead.
            if display[key]['type'] == 'plot':
                window = data[key]['data'][index - half_signal_window_length : index + half_signal_window_length]
                ims.append(ax.plot(window, marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
            elif display[key]['type'] == 'image':
                ims.append(ax.imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)

View File

@@ -0,0 +1,3 @@
from .dual_fc import DualFC
from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding

View File

@@ -0,0 +1,44 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from torch import nn
class DualFC(nn.Module):
    """ weighted sum of two tanh-activated linear layers

        Both layers map input_dim to output_dim; their activations are scaled
        by the learnable scalars alpha and beta (both initialized to 0.5).
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.dense1 = nn.Linear(input_dim, output_dim)
        self.dense2 = nn.Linear(input_dim, output_dim)

        # learnable mixing weights of the two branches
        self.alpha = nn.Parameter(torch.full((1,), 0.5), requires_grad=True)
        self.beta = nn.Parameter(torch.full((1,), 0.5), requires_grad=True)

    def forward(self, x):
        first_branch = self.alpha * torch.tanh(self.dense1(x))
        second_branch = self.beta * torch.tanh(self.dense2(x))
        return first_branch + second_branch

View File

@@ -0,0 +1,71 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" module implementing PCM embeddings for LPCNet """
import math as m
import torch
from torch import nn
class PCMEmbedding(nn.Module):
    """ embedding layer for integer pcm levels

        The embedding rows are initialized as 0.1 * sqrt(12) * (u + i - num_levels / 2)
        for level i, with u drawn uniformly from [-0.5, 0.5), so rows are
        ordered by level.

        Parameters:
        -----------
        embed_dim : int
            size of the embedding vectors
        num_levels : int
            number of pcm levels (rows of the embedding table)
    """
    def __init__(self, embed_dim=128, num_levels=256):
        super(PCMEmbedding, self).__init__()

        self.embed_dim = embed_dim
        self.num_levels = num_levels

        # the table width is embed_dim (the original read an attribute that was never set)
        self.embedding = nn.Embedding(self.num_levels, self.embed_dim)

        # initialize
        with torch.no_grad():
            num_rows, num_cols = self.num_levels, self.embed_dim
            a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
            for i in range(num_rows):
                a[i, :] += m.sqrt(12) * (i - num_rows / 2)
            self.embedding.weight[:, :] = 0.1 * a

    def forward(self, x):
        """ returns the embedding vectors for integer level indices x """
        return self.embedding(x)


class DifferentiablePCMEmbedding(PCMEmbedding):
    """ pcm embedding for real-valued levels

        Interpolates linearly between the embeddings of the two neighbouring
        integer levels, which makes the output differentiable with respect
        to the input.
    """
    def __init__(self, embed_dim, num_levels=256):
        super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)

    def forward(self, x):
        """ returns interpolated embeddings for real-valued levels x

            Parameters:
            -----------
            x : torch.tensor
                floating point levels; assumed to lie in [0, num_levels - 1]

            Returns:
            --------
            torch.tensor
                tensor of shape x.shape + (embed_dim,)
        """
        # integer part is taken from floor(x); the index itself carries no gradient
        x_int = torch.floor(x).detach().long()
        # trailing axis added so the weights broadcast over the embedding axis
        x_frac = (x - x_int).unsqueeze(-1)
        # upper neighbour, clamped to the last valid level
        x_next = torch.clamp(x_int + 1, max=self.num_levels - 1)

        embed_0 = self.embedding(x_int)
        embed_1 = self.embedding(x_next)

        return (1 - x_frac) * embed_0 + x_frac * embed_1

View File

@@ -0,0 +1,497 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from re import sub
import torch
from torch import nn
def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):
    """ creates a subconditioner by name

        Parameters:
        -----------
        method : str
            one of 'additive', 'concatenative' or 'modulative'; any other
            value raises a KeyError

        The remaining positional arguments and **kwargs are handed to the
        constructor of the selected class unchanged.
    """
    registry = {
        'additive' : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative' : ModulativeSubconditioner
    }

    subconditioner_class = registry[method]

    return subconditioner_class(
        number_of_subsamples,
        pcm_embedding_size,
        state_size,
        pcm_levels,
        number_of_signals,
        **kwargs
    )
class Subconditioner(nn.Module):
    """ upsampling by subconditioning

        Upsamples a sequence of states conditioning on pcm signals and
        optionally a feature vector.

        This is the common base class; every method raises and has to be
        provided by the concrete subconditioner.
    """
    def __init__(self):
        super().__init__()

    def forward(self, states, signals, features=None):
        raise Exception("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise Exception("Base class should not be called")

    def get_output_dim(self, index):
        raise Exception("Base class should not be called")
class AdditiveSubconditioner(Subconditioner):
    """ subconditioning by addition

        For every sub-step i >= 1 the pcm signals are embedded, the embeddings
        are summed over the signal axis and added to the running state.
        Sub-step 0 uses the state unchanged.
    """
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        super().__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        # embeddings are added to the state, so both need the same size
        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning state and embedding '
                + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')

        # slot 0 stays empty since the first sub-step needs no embedding
        self.embeddings = [None]
        for step in range(1, self.number_of_subsamples):
            table = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(step), table)
            self.embeddings.append(table)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        stride = self.number_of_subsamples

        c_states = [states]
        for step in range(1, stride):
            # signals belonging to this sub-step, embedded and summed over the signal axis
            summed = torch.sum(self.embeddings[step](signals[:, step::stride]), dim=2)
            c_states.append(c_states[-1] + summed)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning, all but the last dimensions
                must match those of state

            Returns:
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0:
            return state

        embedded = self.embeddings[index](signals)
        return state + torch.sum(embedded, dim=-2)

    def get_output_dim(self, index):
        """ size of the subconditioned state, identical for every index """
        return self.state_size

    def get_average_flops_per_step(self):
        """ additions per step, averaged over the s sub-steps (step 0 is free) """
        s = self.number_of_subsamples
        return (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
class ConcatenativeSubconditioner(Subconditioner):
    """ subconditioning by concatenation

        The flattened signal embeddings of a sub-step are appended to the
        state. In recurrent mode sub-step 0 returns the plain state and every
        later sub-step appends to the previous result; otherwise every
        sub-step (including 0) appends its embeddings to the original state.
    """
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        super().__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        # in recurrent mode slot 0 stays empty since sub-step 0 needs no embedding
        self.embeddings = [None] if self.recurrent else []
        first_step = len(self.embeddings)

        for step in range(first_step, self.number_of_subsamples):
            table = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(step), table)
            self.embeddings.append(table)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        stride = self.number_of_subsamples

        if self.recurrent:
            c_states, first_step = [states], 1
        else:
            c_states, first_step = [], 0

        current = states
        for step in range(first_step, stride):
            # embeddings of all signals of this sub-step merged into one vector
            embedded = torch.flatten(self.embeddings[step](signals[:, step::stride]), -2)

            # recurrent mode grows the previous result, otherwise restart from states
            base = current if self.recurrent else states
            current = torch.cat((base, embedded), dim=-1)
            c_states.append(current)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning, all but the last dimensions
                must match those of state

            Returns:
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0 and self.recurrent:
            return state

        conditioning = torch.flatten(self.embeddings[index](signals), -2)

        if not self.recurrent and index > 0:
            # overwrite previous conditioning vector
            state = state[..., :self.state_size]

        return torch.cat((state, conditioning), dim=-1)

    def get_average_flops_per_step(self):
        """ concatenation is counted as free """
        return 0

    def get_output_dim(self, index):
        """ size of the subconditioned state at the given sub-step """
        appended_size = self.pcm_embedding_size * self.number_of_signals
        if self.recurrent:
            return self.state_size + index * appended_size
        return self.state_size + appended_size
class ModulativeSubconditioner(Subconditioner):
    """ subconditioning by modulation

        For every sub-step i >= 1 a scale alpha and an offset beta are computed
        from the flattened signal embeddings (and, if state_recurrent is set,
        a linear transform of the current state). The new state is
        tanh((1 + alpha) * state + beta). Sub-step 0 uses the state unchanged.
    """
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        super().__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        # input size of the alpha and beta layers
        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # slot 0 of each list stays empty since sub-step 0 is the identity
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for step in range(1, self.number_of_subsamples):
            suffix = str(step)

            table = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + suffix, table)
            self.embeddings.append(table)

            alpha_layer = nn.Linear(self.hidden_size, self.state_size)
            self.alphas.append(alpha_layer)
            self.add_module('alpha_dense_' + suffix, alpha_layer)

            beta_layer = nn.Linear(self.hidden_size, self.state_size)
            self.betas.append(beta_layer)
            self.add_module('beta_dense_' + suffix, beta_layer)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        stride = self.number_of_subsamples

        c_states = [states]
        current = states

        for step in range(1, stride):
            # embeddings of all signals of this sub-step merged into one vector
            conditioning = torch.flatten(self.embeddings[step](signals[:, step::stride]), -2)

            if self.state_recurrent:
                compressed_state = self.state_transform(current)
                conditioning = torch.cat((conditioning, compressed_state), dim=-1)

            alpha = torch.tanh(self.alphas[step](conditioning))
            beta = torch.tanh(self.betas[step](conditioning))

            # new state obtained by modulating previous state
            current = torch.tanh((1 + alpha) * current + beta)
            c_states.append(current)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch
            state : torch.tensor
                state to sub-condition
            signals : torch.tensor
                signals for subconditioning, all but the last dimensions
                must match those of state

            Returns:
            c_state : torch.tensor
                subconditioned state
        """
        if index == 0:
            return state

        conditioning = torch.flatten(self.embeddings[index](signals), -2)

        if self.state_recurrent:
            compressed_state = self.state_transform(state)
            conditioning = torch.cat((conditioning, compressed_state), dim=-1)

        alpha = torch.tanh(self.alphas[index](conditioning))
        beta = torch.tanh(self.betas[index](conditioning))

        return torch.tanh((1 + alpha) * state + beta)

    def get_output_dim(self, index):
        """ size of the subconditioned state, identical for every index """
        return self.state_size

    def get_average_flops_per_step(self):
        """ estimated flops per step, averaged over the s sub-steps (step 0 is free) """
        s = self.number_of_subsamples

        # input size of the alpha and beta layers
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # modulation tanh((1 + alpha) * state + beta): activation estimated
        # by 10 flops plus 3 for the arithmetic
        flops = 13 * self.state_size

        # alpha and beta layers: 2 * A * B flops for Linear(A, B) each, plus
        # 10 flops per output for each of the two tanh activations
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # state transform, only present in state recurrent mode
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps
        flops *= (s - 1) / s

        return flops
class ComparitiveSubconditioner(Subconditioner):
    """ subconditioning by comparison

        For every sub-step i >= 1 the current state is passed through a
        per-step linear transform and then modulated by a scale alpha and an
        offset beta computed from the flattened signal embeddings.

        NOTE(review): this class looks unfinished. forward uses the shared
        alpha_dense / beta_dense layers while the per-step alpha_denses /
        beta_denses, state_transform, gate_dense, error_position and normalize
        are never used, and the modulation only broadcasts if state_size equals
        pcm_embedding_size. Confirm the intended design before relying on it.
    """
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        super(ComparitiveSubconditioner, self).__init__()

        # store the constructor arguments; the original read these attributes
        # without ever setting them, so construction always failed
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        # size of the flattened embeddings of all signals of one sub-step
        conditioning_size = self.number_of_signals * self.pcm_embedding_size

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense = nn.Linear(conditioning_size, self.state_size)
        self.beta_dense = nn.Linear(conditioning_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(conditioning_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(conditioning_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states; presumably of shape (batch, seq_length // s, state_size)
                as for the other subconditioners -- confirm
            signals : torch.tensor
                signals; presumably of shape (batch, seq_length, number_of_signals)
                as for the other subconditioners -- confirm

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating the transformed previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states

View File

@@ -0,0 +1,65 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def find(a, v):
    """ returns the index of v in a, or -1 if it cannot be determined

        Parameters:
        -----------
        a : object
            container offering an index method (e.g. list, tuple or str)
        v : object
            value to look up

        Returns:
        --------
        int
            result of a.index(v), or -1 if that call fails
    """
    try:
        return a.index(v)
    except Exception:
        # any ordinary failure (value missing, no index method, wrong type)
        # maps to -1 as before; unlike the former bare except this no longer
        # swallows KeyboardInterrupt or SystemExit
        return -1
def interleave_tensors(tensors, dim=-2):
    """ interleave list of tensors along sequence dimension

        The tensors are stacked on a new axis inserted at dim, which is then
        merged with the axis in front of it, so entry k of the result's
        merged axis comes from tensors[k % len(tensors)].
    """
    stacked = torch.stack(list(tensors), dim=dim)
    return torch.flatten(stacked, dim - 1, dim)
def _interleave(x, pcm_levels=256):
repeats = pcm_levels // (2*x.size(-1))
x = x.unsqueeze(-1)
p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
return p
def get_pdf_from_tree(x):
pcm_levels = x.size(-1)
p = _interleave(x[..., 1:2])
n = 4
while n <= pcm_levels:
p = p * _interleave(x[..., n//2:n])
n *= 2
return p

View File

@@ -0,0 +1,35 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
def clip_to_int16(x):
    """ limits x to the int16 range [-2**15, 2**15 - 1]

    Values inside the range are returned unchanged; values outside are
    replaced by the nearest bound.
    """
    lower = -(2**15)
    upper = 2**15 - 1

    # same comparisons as max(lower, min(x, upper)), spelled out
    capped = upper if upper < x else x

    return capped if capped > lower else lower

View File

@@ -0,0 +1,44 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def sample_excitation(probs, pitch_corr):
    """ samples one index from probs after sharpening and tail removal

    Parameters:
    -----------
    probs : torch.Tensor
        non-negative weights; singleton dimensions are squeezed before sampling

    pitch_corr : float
        controls the sharpening exponent 1 + max(0, 1.5 * pitch_corr - 0.5)

    Returns:
    --------
    torch.Tensor
        result of torch.multinomial(..., 1) on the processed probabilities
    """

    def norm(p):
        # the small offset avoids a division by zero for an all-zero input
        return p / (p.sum() + 1e-18)

    # lowering the temperature
    probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))

    # cut-off tails; clamp keeps device and dtype of probs, whereas comparing
    # against a freshly created CPU float tensor fails for tensors on other devices
    probs = norm(torch.clamp(probs - 0.002, min=0))

    # sample
    exc = torch.multinomial(probs.squeeze(), 1)

    return exc

View File

@@ -0,0 +1,2 @@
from .gru_sparsifier import GRUSparsifier
from .common import sparsify_matrix, calculate_gru_flops_per_step

View File

@@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
def sparsify_matrix(matrix : torch.Tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    The matrix is divided into blocks of shape block_size. Blocks are ranked
    by their energy (sum of squared entries) and only the
    round(number_of_blocks * density) strongest blocks are kept; all other
    blocks are set to zero. Blocks whose energy ties with the weakest
    surviving block are kept as well.

    Parameters:
    -----------
    matrix : torch.Tensor
        matrix to sparsify (two-dimensional)

    density : float
        target density, i.e. the fraction of blocks to keep

    block_size : [int, int]
        block size dimensions

    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires a square matrix and defaults to False

    return_mask : bool
        If true, the element-wise mask is returned along with the sparsified matrix

    Returns:
    --------
    masked_matrix, or the tuple (masked_matrix, mask) if return_mask is True.
    The mask does not include the diagonal preserved by keep_diagonal.

    Raises:
    -------
    ValueError
        if block_size does not divide the matrix shape or if keep_diagonal is
        requested for a non-square matrix
    """

    m, n = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # create mask
    if number_of_survivors == 0:
        # no block survives; thresholding at 0 would keep every block since
        # energies are never negative
        mask = torch.zeros_like(block_energies)
    else:
        # masking threshold: energy of the weakest surviving block
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
        mask = torch.ones_like(block_energies)
        mask[block_energies < threshold] = 0

    # expand block mask to element resolution
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
    """ estimates the number of floating point operations of one GRU step

    Parameters:
    -----------
    gru : object
        provides the attributes input_size and hidden_size

    sparsification_dict : dict
        maps weight names ('W_ir', 'W_iz', 'W_in', 'W_hr', 'W_hz', 'W_hn') to
        sequences whose first entry is the density of that matrix; missing
        entries count as dense. The default dict is only read, never modified.

    drop_input : bool
        if True, the input matrix vector multiplications are not counted

    Returns:
    --------
    estimated flops per step
    """

    def density_of(key):
        return sparsification_dict.get(key, [1])[0]

    num_inputs = gru.input_size
    num_units = gru.hidden_size

    # average density over the three gates of each weight matrix
    mean_input_density = (density_of('W_ir') + density_of('W_in') + density_of('W_iz')) / 3
    mean_recurrent_density = (density_of('W_hr') + density_of('W_hn') + density_of('W_hz')) / 3

    total = 0

    # input matrix vector multiplications
    if not drop_input:
        total += 2 * 3 * num_inputs * num_units * mean_input_density

    # recurrent matrix vector multiplications
    total += 2 * 3 * num_units * num_units * mean_recurrent_density

    # biases
    total += 6 * num_units

    # activations estimated by 10 flops per activation
    total += 30 * num_units

    return total

View File

@@ -0,0 +1,187 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import torch
from .common import sparsify_matrix
class GRUSparsifier:
    # gate order used to slice the row-wise stacked GRU weight matrices in step()
    _INPUT_KEYS = ('W_ir', 'W_iz', 'W_in')
    _RECURRENT_KEYS = ('W_hr', 'W_hz', 'W_hn')

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
                of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
                'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
                update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
                should be kept.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to GRUSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> gru = torch.nn.GRU(10, 20)
            >>> sparsify_dict = {
            ...         'W_ir' : (0.5, [2, 2], False),
            ...         'W_iz' : (0.6, [2, 2], False),
            ...         'W_in' : (0.7, [2, 2], False),
            ...         'W_hr' : (0.1, [4, 4], True),
            ...         'W_hz' : (0.2, [4, 4], True),
            ...         'W_hn' : (0.3, [4, 4], True),
            ...     }
            >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """
        # just copying parameters...
        self.start      = start
        self.stop       = stop
        self.interval   = interval
        self.exponent   = exponent
        self.task_list  = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        # most recent mask per weight key, shared by all tasks (kept for backward compatibility)
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

        # most recent mask per (task index, weight key); change detection uses
        # this one so that masks of different GRUs are never compared
        self._masks_per_task = dict()

    def step(self, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """
        # compute current interpolation factor
        self.step_counter += 1

        if self.step_counter < self.start:
            return
        elif self.step_counter < self.stop:
            # update only every self.interval-th interval
            if self.step_counter % self.interval:
                return

            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            alpha = 0

        with torch.no_grad():
            for task_index, (gru, params) in enumerate(self.task_list):
                # input weights
                self._sparsify_gates(task_index, gru, 'weight_ih_l0', self._INPUT_KEYS, params, alpha, verbose)
                # recurrent weights
                self._sparsify_gates(task_index, gru, 'weight_hh_l0', self._RECURRENT_KEYS, params, alpha, verbose)

    def _sparsify_gates(self, task_index, gru, weight_name, keys, params, alpha, verbose):
        """ sparsifies in place the gate sub-matrices of one stacked GRU weight matrix """
        hidden_size = gru.hidden_size

        for i, key in enumerate(keys):
            if key not in params:
                continue

            # interpolate between dense (alpha = 1) and target density (alpha = 0)
            density = alpha + (1 - alpha) * params[key][0]
            if verbose:
                print(f"[{self.step_counter}]: {key} density: {density}")

            weight = getattr(gru, weight_name)
            # rows of the i-th gate within the stacked weight matrix
            rows = slice(i * hidden_size, (i + 1) * hidden_size)

            weight[rows, : ], new_mask = sparsify_matrix(
                weight[rows, : ],
                density, # density
                params[key][1], # block_size
                params[key][2], # keep_diagonal (might want to set this to False)
                return_mask=True
            )

            # after stop the density is final, so a changing mask is reported
            previous_mask = self._masks_per_task.get((task_index, key))
            if previous_mask is not None:
                if self.step_counter > self.stop and not torch.all(previous_mask == new_mask):
                    print(f"sparsification mask {key} changed for gru {gru}")

            self._masks_per_task[(task_index, key)] = new_mask
            self.last_masks[key] = new_mask
if __name__ == "__main__":
    # small smoke run: sparsify a toy GRU for 100 steps and print the densities
    print("Testing sparsifier")

    demo_gru = torch.nn.GRU(10, 20)

    # (density, block shape, keep_diagonal) per weight matrix
    demo_params = dict(
        W_ir=(0.5, [2, 2], False),
        W_iz=(0.6, [2, 2], False),
        W_in=(0.7, [2, 2], False),
        W_hr=(0.1, [4, 4], True),
        W_hz=(0.2, [4, 4], True),
        W_hn=(0.3, [4, 4], True),
    )

    demo_sparsifier = GRUSparsifier([(demo_gru, demo_params)], 0, 100, 10)

    for _ in range(100):
        demo_sparsifier.step(verbose=True)

View File

@@ -0,0 +1,157 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
from models import multi_rate_lpcnet
import copy
# dataset descriptions
#
# feature_frame_layout maps a feature name to a [start, stop] column range of
# a feature frame; for version 2 the ranges cover columns 0..36 without gaps,
# which matches feature_frame_length.
# NOTE(review): signal_frame_layout presumably maps a signal name to its
# column within a signal frame -- confirm against the dataset loader.
dataset_template_v2 = {
    'version'               : 2,
    'feature_file'          : 'features.f32',
    'signal_file'           : 'data.s16',
    'frame_length'          : 160,
    'feature_frame_length'  : 36,
    'signal_frame_length'   : 2,
    'feature_dtype'         : 'float32',
    'signal_dtype'          : 'int16',
    'feature_frame_layout'  : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
    'signal_frame_layout'   : {'last_signal' : 0, 'signal': 1}
}

dataset_template_v1 = {
    'version'               : 1,
    'feature_file'          : 'features.f32',
    'signal_file'           : 'data.u8',
    'frame_length'          : 160,
    'feature_frame_length'  : 55,
    'signal_frame_length'   : 4,
    'feature_dtype'         : 'float32',
    'signal_dtype'          : 'uint8',
    'feature_frame_layout'  : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
    'signal_frame_layout'   : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3}
}

# lpcnet
lpcnet_config = {
    'frame_size'                : 160,
    'gru_a_units'               : 384,
    'gru_b_units'               : 64,
    'feature_conditioning_dim'  : 128,
    'feature_conv_kernel_size'  : 3,
    'period_levels'             : 257,
    'period_embedding_dim'      : 64,
    'signal_embedding_dim'      : 128,
    'signal_levels'             : 256,
    'feature_dimension'         : 19,
    'output_levels'             : 256,
    'lpc_gamma'                 : 0.9,
    'features'                  : ['cepstrum', 'periods', 'pitch_corr'],
    'signals'                   : ['last_signal', 'prediction', 'last_error'],
    'input_layout'              : { 'signals'  : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
                                    'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
    'target'                    : 'error',
    'feature_history'           : 2,
    'feature_lookahead'         : 2,
    # per GRU: sparsification schedule and (density, block shape, keep_diagonal) per weight matrix
    'sparsification'            : {
        'gru_a' : {
            'start'     : 10000,
            'stop'      : 30000,
            'interval'  : 100,
            'exponent'  : 3,
            'params'    : {
                'W_hr' : (0.05, [4, 8], True),
                'W_hz' : (0.05, [4, 8], True),
                'W_hn' : (0.2, [4, 8], True)
            },
        },
        'gru_b' : {
            'start'     : 10000,
            'stop'      : 30000,
            'interval'  : 100,
            'exponent'  : 3,
            'params'    : {
                'W_ir' : (0.5, [4, 8], False),
                'W_iz' : (0.5, [4, 8], False),
                'W_in' : (0.5, [4, 8], False)
            },
        }
    },
    'add_reference_phase'       : False,
    'reference_phase_dim'       : 0
}

# multi rate
subconditioning = {
    'subconditioning_a' : {
        'number_of_subsamples'  : 2,
        'method'                : 'modulative',
        'signals'               : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size'    : 64,
        'kwargs'                : dict()
    },
    'subconditioning_b' : {
        'number_of_subsamples'  : 2,
        'method'                : 'modulative',
        'signals'               : ['last_signal', 'prediction', 'last_error'],
        'pcm_embedding_size'    : 64,
        'kwargs'                : dict()
    }
}

# deep copy: a shallow copy would share the nested dicts and lists (e.g.
# 'sparsification' and 'input_layout') with lpcnet_config, so editing one
# configuration would silently change the other
multi_rate_lpcnet_config = copy.deepcopy(lpcnet_config)
multi_rate_lpcnet_config['subconditioning'] = subconditioning

training_default = {
    'batch_size'        : 256,
    'epochs'            : 20,
    'lr'                : 1e-3,
    'lr_decay_factor'   : 2.5e-5,
    'adam_betas'        : [0.9, 0.99],
    'frames_per_sample' : 15
}

lpcnet_setup = {
    'dataset'       : '/local/datasets/lpcnet_training',
    'lpcnet'        : {'config' : lpcnet_config, 'model': 'lpcnet'},
    'training'      : training_default
}

# the multi rate setup starts as an independent copy of the lpcnet setup and
# then swaps in its own model name and configuration
multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'

# available setups by name
setup_dict = {
    'lpcnet'     : lpcnet_setup,
    'multi_rate' : multi_rate_lpcnet_setup
}

View File

@@ -0,0 +1,58 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import math as m
import torch
def ulaw2lin(u):
    """ converts mu-law values to linear values

    The input is centered by subtracting 128 and expanded according to
    sign * (32768 / 255) * (256 ** (|u - 128| / 128) - 1).

    Parameters:
    -----------
    u : torch.Tensor
        mu-law values; integer tensors are converted to the default float
        dtype first

    Returns:
    --------
    torch.Tensor
        linear values
    """
    # promote integer input up front: for unsigned types such as uint8 the
    # subtraction below would wrap around instead of becoming negative
    if not torch.is_floating_point(u):
        u = u.to(torch.get_default_dtype())

    scale_1 = 32768.0 / 255.0
    u = u - 128
    s = torch.sign(u)
    u = torch.abs(u)

    return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)
def lin2ulawq(x):
    """ converts linear values to quantized mu-law values

    Computes 128 + round(sign(x) * 128 * log(1 + 255 / 32768 * |x|) / log(256))
    and clips the result to [0, 255].

    Parameters:
    -----------
    x : torch.Tensor
        linear values; integer tensors are converted to the default float
        dtype first

    Returns:
    --------
    torch.Tensor
        rounded mu-law values in [0, 255] (floating point dtype)
    """
    # promote integer input up front: the absolute value of the most negative
    # integer (e.g. -32768 for int16) overflows and would end up as NaN
    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())

    scale = 255.0 / 32768.0
    s = torch.sign(x)
    x = torch.abs(x)
    u = s * (128 * torch.log(1 + scale * x) / m.log(256))
    u = torch.clip(128 + torch.round(u), 0, 255)

    return u
def lin2ulaw(x):
    """ converts linear values to mu-law values without rounding

    Computes 128 + sign(x) * 128 * log(1 + 255 / 32768 * |x|) / log(256)
    and clips the result to [0, 255].

    Parameters:
    -----------
    x : torch.Tensor
        linear values; integer tensors are converted to the default float
        dtype first

    Returns:
    --------
    torch.Tensor
        mu-law values in [0, 255] (floating point dtype)
    """
    # promote integer input up front: the absolute value of the most negative
    # integer (e.g. -32768 for int16) overflows and would end up as NaN
    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())

    scale = 255.0 / 32768.0
    s = torch.sign(x)
    x = torch.abs(x)
    # log(256) is a plain constant: torch.log only accepts tensors, so the
    # math module is used for it
    u = s * (128 * torch.log(1 + scale * x) / m.log(256))
    u = torch.clip(128 + u, 0, 255)

    return u

View File

@@ -0,0 +1,43 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import wave
def wavwrite16(filename, x, fs):
    """ writes x as int16 to file with name filename

        If x.dtype is int16 x is written as is. Otherwise,
        it is scaled by 2**15 - 1, limited to the int16 range
        and converted to int16.

        Parameters:
        -----------
        filename : str or file-like object
            target passed on to wave.open

        x : array
            samples; needs to provide dtype, astype and tobytes.
            NOTE(review): presumably a one-dimensional numpy array, since
            len(x) is used as the number of frames of a mono file -- confirm

        fs : int
            sampling rate
    """
    if x.dtype != 'int16':
        # saturate samples outside the representable range instead of letting
        # the integer conversion wrap around
        x = ((2**15 - 1) * x).clip(-2**15, 2**15 - 1).astype('int16')

    with wave.open(filename, 'wb') as f:
        # mono, 2 bytes per sample, no compression
        f.setparams((1, 2, fs, len(x), 'NONE', ""))
        f.writeframes(x.tobytes())