add some code
This commit is contained in:
55
managed_components/78__esp-opus/dnn/torch/rdovae/README.md
Normal file
55
managed_components/78__esp-opus/dnn/torch/rdovae/README.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Deep REDundancy (DRED) with RDO-VAE
|
||||
|
||||
This is a rate-distortion-optimized variational autoencoder (RDO-VAE) designed
|
||||
for coding redundancy information. Pre-trained models are provided as C code
|
||||
in the dnn/ directory with the corresponding model in dnn/models/ directory
|
||||
(name starts with rdovae_). If you don't want to train a new DRED model, you can
|
||||
skip straight to the Inference section.
|
||||
|
||||
## Data preparation
|
||||
|
||||
First, fetch all the data from the datasets.txt file using:
|
||||
```
|
||||
./download_datasets.sh
|
||||
```
|
||||
|
||||
Then concatenate and resample the data into a single 16-kHz file:
|
||||
```
|
||||
./process_speech.sh
|
||||
```
|
||||
The script will produce an all_speech.pcm speech file in raw 16-bit PCM format.
|
||||
|
||||
|
||||
For data preparation you need to build Opus as detailed in the top-level README.
|
||||
You will need to use the --enable-dred configure option.
|
||||
The build will produce an executable named "dump_data".
|
||||
To prepare the training data, run:
|
||||
```
|
||||
./dump_data -train all_speech.pcm all_features.f32 /dev/null
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
To perform training, run the following command:
|
||||
```
|
||||
python ./train_rdovae.py --sequence-length 400 --split-mode random_split --state-dim 80 --batch-size 512 --epochs 400 --lambda-max 0.04 --lr 0.003 --lr-decay-factor 0.0001 all_features.f32 output_dir
|
||||
```
|
||||
The final model will be in output_dir/checkpoints/chechpoint_400.pth.
|
||||
|
||||
The model can be converted to C using:
|
||||
```
|
||||
python export_rdovae_weights.py output_dir/checkpoints/chechpoint_400.pth dred_c_dir
|
||||
```
|
||||
which will create a number of C source and header files in the dred_c_dir directory.
|
||||
Copy these files to the opus/dnn/ directory (replacing the existing ones) and recompile Opus.
|
||||
|
||||
## Inference
|
||||
|
||||
DRED is integrated within the Opus codec and can be evaluated using the opus_demo
|
||||
executable. For example:
|
||||
```
|
||||
./opus_demo voip 16000 1 64000 -loss 50 -dred 100 -sim_loss 50 input.pcm output.pcm
|
||||
```
|
||||
This tells the encoder to encode a 16 kHz raw audio file at 64 kb/s using up to 1 second
|
||||
of redundancy (units are based on 10-ms) and then simulate 50% loss. Refer to `opus_demo --help`
|
||||
for more details.
|
||||
@@ -0,0 +1,6 @@
|
||||
# Download every URL listed in datasets.txt into ./datasets.

# -p: re-running the script must not fail because the directory already exists.
mkdir -p datasets
# Stop instead of downloading into the wrong directory if the cd fails.
cd datasets || exit 1
# grep prints the lines containing "https"; each whitespace-separated word of
# that output is fetched. The path is relative to the datasets/ directory.
for i in $(grep https ../../../datasets.txt)
do
    wget "$i"
done
|
||||
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
import argparse
import sys

# Make the sibling weight-exchange directory importable; the wexchange
# imports below are resolved through this path entry.
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')

# Arguments are parsed at module level, before torch is imported.
args = parser.parse_args()

import torch
import numpy as np

from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
|
||||
|
||||
def print_xml(xmlout, val, param, anchor, name):
    """Write a 2-D table of values to ``xmlout`` as an XML ``<table>`` element.

    Args:
        xmlout: writable text stream; only its ``write`` method is used.
        val: 2-D array read as ``val[j][k]``. Axis 0 supplies the ``Q<j>``
            columns and axis 1 supplies the rows (one per ``k``).
        param: parameter label used in the table's ``<name>`` element.
        anchor: prefix of the table's ``anchor`` attribute.
        name: suffix of the ``anchor`` attribute and of the ``<name>`` text.
    """
    num_columns, num_rows = val.shape[0], val.shape[1]

    # Build one header cell per column that the rows below actually emit.
    # The header used to be hard-coded to Q0..Q15, which only lined up with
    # the body when val had exactly 16 entries along axis 0.
    header_cells = "".join(f"<th>Q{j}</th>" for j in range(num_columns))

    xmlout.write(
        "\n"
        f'<table anchor="{anchor}_{name}">\n'
        f"<name>{param} values for {name}</name>\n"
        "<thead>\n"
        f"<tr><th>k</th>{header_cells}</tr>\n"
        "</thead>\n"
        "<tbody>\n"
    )

    for k in range(num_rows):
        xmlout.write(f" <tr><th>{k}</th>")
        for j in range(num_columns):
            xmlout.write(f"<th>{val[j][k]}</th>")
        xmlout.write("</tr>\n")

    xmlout.write(
        "\n"
        "</tbody>\n"
        "</table>\n"
    )
|
||||
def dump_statistical_model(writer, w, name, xmlout):
    """Quantize the statistical-model tables and write them as C arrays and XML.

    Args:
        writer: object whose ``source`` and ``header`` attributes are written to.
        w: tensor indexed as ``w[level, row, dim]``; rows 0, 1, 4 and 5 of the
            middle axis are read here.
        name: tag inserted into the exported symbol names (``dred_{name}_...``).
        xmlout: text stream receiving the XML rendering of the same tables.

    Returns:
        ``(N, mask, scale)`` where ``mask`` is the boolean per-dimension
        selection computed below, ``N`` is the number of dimensions kept and
        ``scale`` is ``scales_norm[mask]`` as a torch tensor.
    """
    levels = w.shape[0]

    print("printing statistical model")
    quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
    r = torch.sigmoid(w[:, 5, :]).numpy()
    p0 = torch.sigmoid(w[:, 4, :]).numpy()
    p0 = 1 - r ** (0.5 + 0.5 * p0)

    # Normalize each dimension so its largest scale is 255/256, i.e. the
    # largest Q8 value after rounding is 255.
    scales_norm = 255. / 256. / (1e-15 + np.max(quant_scales, axis=0))
    quant_scales = quant_scales * scales_norm

    # Q8 conversion: multiply by 2**8 and round.
    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    dead_zone_q8 = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
    r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
    p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

    # Keep a dimension only if r is non-zero for at least one level and
    # p0 is below 255 for at least one level.
    mask = (np.max(r_q8, axis=0) > 0) * (np.min(p0_q8, axis=0) < 255)
    quant_scales_q8 = quant_scales_q8[:, mask]
    dead_zone_q8 = dead_zone_q8[:, mask]
    r_q8 = r_q8[:, mask]
    p0_q8 = p0_q8[:, mask]
    N = r_q8.shape[-1]

    # (symbol suffix, table, XML parameter label, XML anchor prefix)
    tables = (
        ('quant_scales', quant_scales_q8, "Scale", "scale"),
        ('dead_zone', dead_zone_q8, "Dead zone", "deadzone"),
        ('r', r_q8, "Decay (r)", "decay"),
        ('p0', p0_q8, "P(0)", "p0"),
    )

    for suffix, table, _, _ in tables:
        print_vector(writer.source, table, f'dred_{name}_{suffix}_q8', dtype='opus_uint8', static=False)

    for _, table, param, anchor in tables:
        print_xml(xmlout, table, param, anchor, name)

    # Matching extern declarations; each array holds levels * N entries.
    writer.header.write(
        "\n"
        + "".join(
            f"extern const opus_uint8 dred_{name}_{suffix}_q8[{levels * N}];\n"
            for suffix, _, _, _ in tables
        )
        + "\n"
    )

    return N, mask, torch.tensor(scales_norm[mask])
|
||||
|
||||
|
||||
def _prune_and_pad_output(layer, mask, scale, dim):
    """Prune, rescale and zero-pad the output rows of ``layer`` in place.

    Keeps the rows selected by ``mask``, divides them (and the bias) by
    ``scale``, then appends zero rows so the row count is a multiple of 8.
    Returns the padded row count.
    """
    padded_dim = (dim + 7) // 8 * 8
    pad = padded_dim - dim

    w = layer.weight[mask, :]
    w = w / scale[:, None]
    w = torch.cat([w, torch.zeros(pad, w.shape[1])], dim=0)

    b = layer.bias[mask]
    b = b / scale
    b = torch.cat([b, torch.zeros(pad)], dim=0)

    layer.weight = torch.nn.Parameter(w)
    layer.bias = torch.nn.Parameter(b)
    return padded_dim


def c_export(args, model):
    """Export the model as C data files plus a constants header.

    Writes ``dred_rdovae_enc_data``, ``dred_rdovae_dec_data``,
    ``dred_rdovae_stats_data`` and ``dred_rdovae_constants`` under
    ``args.output_dir`` through ``CWriter``, and ``stats.xml`` in the current
    working directory. Several layers of ``model`` are modified in place.

    Args:
        args: parsed arguments; ``checkpoint`` and ``output_dir`` are read.
        model: the loaded model whose submodules are dumped.
    """
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message, model_struct_name='RDOVAEEnc')
    dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
    constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)

    # some custom includes
    for writer in [enc_writer, dec_writer]:
        writer.header.write(
            "\n"
            '#include "opus_types.h"\n'
            "\n"
            '#include "dred_rdovae.h"\n'
            "\n"
            '#include "dred_rdovae_constants.h"\n'
            "\n"
        )

    stats_writer.header.write(
        "\n"
        '#include "opus_types.h"\n'
        "\n"
        '#include "dred_rdovae_constants.h"\n'
        "\n"
    )

    latent_out = model.get_submodule('core_encoder.module.z_dense')
    state_out = model.get_submodule('core_encoder.module.state_dense_2')
    orig_latent_dim = latent_out.weight.shape[0]
    orig_state_dim = state_out.weight.shape[0]

    # statistical model
    qembedding = model.statistical_model.quant_embedding.weight.detach()
    levels = qembedding.shape[0]
    qembedding = torch.reshape(qembedding, (levels, 6, -1))

    # The XML file is only needed while the two tables are dumped; the with
    # block makes sure it is flushed and closed (it was left open before).
    with open("stats.xml", "w") as xmlout:
        latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
        state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)

    # Encoder outputs: drop masked-out dimensions, divide by the scale and
    # pad to a multiple of 8.
    padded_latent_dim = _prune_and_pad_output(latent_out, latent_mask, latent_scale, latent_dim)
    padded_state_dim = _prune_and_pad_output(state_out, state_mask, state_scale, state_dim)

    # Decoder inputs: keep the same dimensions and multiply by the scale the
    # encoder outputs were divided by.
    latent_in = model.get_submodule('core_decoder.module.dense_1')
    state_in = model.get_submodule('core_decoder.module.hidden_init')
    latent_in.weight = torch.nn.Parameter(latent_in.weight[:, latent_mask] * latent_scale)
    state_in.weight = torch.nn.Parameter(state_in.weight[:, state_mask] * state_scale)

    # encoder
    # Tuples are (module path, exported name, activation label, quantize);
    # the activation label is not used by this function.
    encoder_dense_layers = [
        ('core_encoder.module.dense_1'       , 'enc_dense1', 'TANH', False,),
        ('core_encoder.module.z_dense'       , 'enc_zdense', 'LINEAR', True,),
        ('core_encoder.module.state_dense_1' , 'gdense1'   , 'TANH', True,),
        ('core_encoder.module.state_dense_2' , 'gdense2'   , 'TANH', True)
    ]

    for name, export_name, _, quantize in encoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(enc_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)

    encoder_gru_layers = [
        ('core_encoder.module.gru1' , 'enc_gru1', 'TANH', True),
        ('core_encoder.module.gru2' , 'enc_gru2', 'TANH', True),
        ('core_encoder.module.gru3' , 'enc_gru3', 'TANH', True),
        ('core_encoder.module.gru4' , 'enc_gru4', 'TANH', True),
        ('core_encoder.module.gru5' , 'enc_gru5', 'TANH', True),
    ]

    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in encoder_gru_layers])

    encoder_conv_layers = [
        ('core_encoder.module.conv1.conv' , 'enc_conv1', 'TANH', True),
        ('core_encoder.module.conv2.conv' , 'enc_conv2', 'TANH', True),
        ('core_encoder.module.conv3.conv' , 'enc_conv3', 'TANH', True),
        ('core_encoder.module.conv4.conv' , 'enc_conv4', 'TANH', True),
        ('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH', True),
    ]

    enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in encoder_conv_layers])

    # NOTE(review): dropping the last reference presumably lets CWriter
    # finalize its output files - confirm against CWriter.
    del enc_writer

    # decoder
    decoder_dense_layers = [
        ('core_decoder.module.dense_1'     , 'dec_dense1', 'TANH', False),
        ('core_decoder.module.glu1.gate'   , 'dec_glu1', 'TANH', True),
        ('core_decoder.module.glu2.gate'   , 'dec_glu2', 'TANH', True),
        ('core_decoder.module.glu3.gate'   , 'dec_glu3', 'TANH', True),
        ('core_decoder.module.glu4.gate'   , 'dec_glu4', 'TANH', True),
        ('core_decoder.module.glu5.gate'   , 'dec_glu5', 'TANH', True),
        ('core_decoder.module.output'      , 'dec_output', 'LINEAR', True),
        ('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
        ('core_decoder.module.gru_init'    , 'dec_gru_init', 'TANH', True),
    ]

    for name, export_name, _, quantize in decoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(dec_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)

    decoder_gru_layers = [
        ('core_decoder.module.gru1' , 'dec_gru1', 'TANH', True),
        ('core_decoder.module.gru2' , 'dec_gru2', 'TANH', True),
        ('core_decoder.module.gru3' , 'dec_gru3', 'TANH', True),
        ('core_decoder.module.gru4' , 'dec_gru4', 'TANH', True),
        ('core_decoder.module.gru5' , 'dec_gru5', 'TANH', True),
    ]

    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in decoder_gru_layers])

    decoder_conv_layers = [
        ('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH', True),
        ('core_decoder.module.conv2.conv' , 'dec_conv2', 'TANH', True),
        ('core_decoder.module.conv3.conv' , 'dec_conv3', 'TANH', True),
        ('core_decoder.module.conv4.conv' , 'dec_conv4', 'TANH', True),
        ('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH', True),
    ]

    dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in decoder_conv_layers])

    del dec_writer

    del stats_writer

    # constants
    # DRED_ENC_MAX_RNN_NEURONS now uses the encoder GRU maximum; it was
    # previously written from enc_max_conv_inputs, duplicating the
    # DRED_ENC_MAX_CONV_INPUTS value on the next line.
    constants_writer.header.write(
        "\n"
        f"#define DRED_NUM_FEATURES {model.feature_dim}\n\n"
        f"#define DRED_LATENT_DIM {latent_dim}\n\n"
        f"#define DRED_STATE_DIM {state_dim}\n\n"
        f"#define DRED_PADDED_LATENT_DIM {padded_latent_dim}\n\n"
        f"#define DRED_PADDED_STATE_DIM {padded_state_dim}\n\n"
        f"#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}\n\n"
        f"#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}\n\n"
        f"#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}\n\n"
        f"#define DRED_ENC_MAX_RNN_NEURONS {enc_max_rnn_units}\n\n"
        f"#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}\n\n"
        f"#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}\n\n"
    )

    del constants_writer
|
||||
|
||||
|
||||
def numpy_export(args, model):
    """Dump selected submodules of ``model`` into per-layer exchange folders.

    For every entry of the table below, the submodule at the given module
    path is passed to ``dump_torch_weights`` together with the path
    ``args.output_dir/<exchange name>``.

    Args:
        args: parsed arguments; only ``output_dir`` is read.
        model: the loaded model providing ``get_submodule``.
    """
    # exchange name -> module path inside the model
    exchange_name_to_name = {
        'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
        'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
        'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
        'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
        'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
        'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
        'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
        'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
        'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
        'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
        'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
        'statistical_model_embedding'   : 'statistical_model.quant_embedding',
        'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
        'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
        'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
        'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
        'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
        'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
        'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
        'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
        'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
        'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
        'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
        'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
    }

    # Every module path occurs once, so walking the table directly visits the
    # same (name, exchange_name) pairs in the same order as an inverted copy.
    for exchange_name, name in exchange_name_to_name.items():
        print(f"printing layer {name}...")
        dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))
|
||||
|
||||
|
||||
if __name__ == "__main__":

    os.makedirs(args.output_dir, exist_ok=True)

    # load model from checkpoint
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    # strict=False: key mismatches are returned here and handled below
    # instead of being raised by load_state_dict itself.
    missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)

    def _remove_weight_norm(m):
        """Remove weight norm from ``m``; modules without it are left untouched."""
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return

    model.apply(_remove_weight_norm)

    # Missing keys are fatal; unmatched keys only produce a warning.
    if len(missing_keys) > 0:
        raise ValueError("error: missing keys in state dict")

    if len(unmatched_keys) > 0:
        print(f"warning: the following keys were unmatched {unmatched_keys}")

    if args.format == 'C':
        c_export(args, model)
    elif args.format == 'numpy':
        numpy_export(args, model)
    else:
        raise ValueError(f'error: unknown export format {args.format}')
|
||||
212
managed_components/78__esp-opus/dnn/torch/rdovae/fec_encoder.py
Normal file
212
managed_components/78__esp-opus/dnn/torch/rdovae/fec_encoder.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe and Jean-Marc Valin */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
import subprocess
import argparse

# Hide all CUDA devices; this is set before torch is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')

parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

args = parser.parse_args()

import numpy as np
from scipy.io import wavfile
import torch

from rdovae import RDOVAE
from packets import write_fec_packets

torch.set_num_threads(4)

checkpoint = torch.load(args.checkpoint, map_location="cpu")
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")

lpc_order = 16

## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first package
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump data has a (feature) delay of 10ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm'):
    signal = np.fromfile(args.input, dtype='int16')

elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# write signal and call dump_data to create features

feature_file = os.path.splitext(args.input)[0] + '_features.f32'
# Pass an argument list and no shell, so paths containing spaces or shell
# metacharacters reach dump_data unchanged.
command = [args.dump_data, '-test', padded_signal_file, feature_file]
r = subprocess.run(command)
if r.returncode != 0:
    raise RuntimeError(f"command '{' '.join(command)}' failed with exit code {r.returncode}")

# load features
nb_features = model.feature_dim + lpc_order
nb_used_features = model.feature_dim

# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
# round down to an even number of subframes (two subframes per frame)
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]

# quant_ids in reverse decoding order
quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()

print(f"using quantization levels {quant_ids}...")

# convert input to torch tensors
features = torch.from_numpy(features)


# run encoder
print("running fec encoder...")
with torch.no_grad():

    # encoding
    z, states, state_size = model.encode(features)


    # decoder on packet chunks
    input_length = args.num_redundancy_frames // 2
    offset = args.num_redundancy_frames - 1

    packets = []
    packet_sizes = []

    for i in range(offset, num_frames):
        print(f"processing frame {i - offset}...")
        # quantize / unquantize latent vectors
        zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :])
        zi, rates = model.quantize(zi, quant_ids)
        zi = model.unquantize(zi, quant_ids)

        features = model.decode(zi, states[:, i : i + 1, :])
        packets.append(features.squeeze(0).numpy())
        # total of rates plus state_size, rounded up to a multiple of 8
        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
        packet_sizes.append(packet_size)


# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
write_fec_packets(packet_file, packets, packet_sizes)


print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")

# assemble features according to loss file
if args.lossfile is not None:
    num_packets = len(packets)
    loss = np.loadtxt(args.lossfile, dtype='int16')
    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
    foffset = -2
    ptr = 0
    count = 2
    for i in range(num_packets):
        if (loss[i] == 0) or (i == num_packets - 1):
            # Packet received (or last packet): copy its last `count` rows,
            # which also covers the rows skipped for preceding lost packets.
            fec_out[ptr:ptr+count,:] = packets[i][foffset:, :]

            ptr += count
            foffset = -2
            count = 2
        else:
            # Packet lost: widen the window taken from the next received one.
            count += 2
            foffset -= 2

    fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32)
    fec_out_full[:, : fec_out.shape[-1]] = fec_out

    # packet_file always ends in '.fec'; [:-4] strips that suffix
    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


if args.debug_output:
    import itertools

    batches = [4]
    offsets = [0, 2 * args.num_redundancy_frames - 4]

    # sanity checks
    # 1. concatenate features at offset 0
    for batch, offset in itertools.product(batches, offsets):

        # NOTE(review): stop is derived from shape[1] but is then used to
        # slice axis 0 of each packet - confirm this is the intended axis.
        stop = packets[0].shape[1] - offset
        test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0)

        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[:, :]

        print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
        test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
|
||||
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
# Hide all CUDA devices; this is set before torch is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import argparse


parser = argparse.ArgumentParser()

parser.add_argument('exchange_folder', type=str, help='exchange folder path')
parser.add_argument('output', type=str, help='path to output model checkpoint')

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)

# Arguments are parsed at module level, before torch is imported.
args = parser.parse_args()

import torch
from rdovae import RDOVAE
from wexchange.torch import load_torch_weights
|
||||
|
||||
# Maps the name of each exchanged layer (used as sub-path inside the exchange
# folder) to the dotted submodule path that is looked up on the RDOVAE model
# with model.get_submodule().
# NOTE(review): the target paths must exist in the RDOVAE implementation that
# is imported by this script - confirm against the model definition in use.
exchange_name_to_name = {
    # encoder
    'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1',
    'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2',
    'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3',
    'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4',
    'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5',
    'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1',
    'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2',
    'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1',
    'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2',
    'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3',
    'encoder_stack_layer9_conv' : 'core_encoder.module.conv1',
    # statistical model
    'statistical_model_embedding' : 'statistical_model.quant_embedding',
    # decoder
    'decoder_state1_dense' : 'core_decoder.module.gru_1_init',
    'decoder_state2_dense' : 'core_decoder.module.gru_2_init',
    'decoder_state3_dense' : 'core_decoder.module.gru_3_init',
    'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1',
    'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2',
    'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3',
    'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4',
    'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5',
    'decoder_stack_layer9_dense' : 'core_decoder.module.output',
    'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1',
    'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2',
    'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3'
}
|
||||
|
||||
if __name__ == "__main__":

    # parameters
    num_features = args.num_features
    latent_dim = args.latent_dim
    quant_levels = args.quant_levels
    cond_size = args.cond_size
    cond_size2 = args.cond_size2
    state_dim = args.state_dim

    # model: the constructor arguments are stored alongside the weights so
    # that the model can be re-instantiated from the checkpoint alone
    model_args = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
    model_kwargs = {'state_dim': state_dim}
    checkpoint = {
        'model_args': model_args,
        'model_kwargs': model_kwargs,
    }
    model = RDOVAE(*model_args, **model_kwargs)

    # exchanged layers grouped by type; the groups are processed in the order
    # dense, gru, conv1d, embedding
    dense_layer_names = [
        'encoder_stack_layer1_dense',
        'encoder_stack_layer3_dense',
        'encoder_stack_layer5_dense',
        'encoder_stack_layer7_dense',
        'encoder_stack_layer8_dense',
        'encoder_state_layer1_dense',
        'encoder_state_layer2_dense',
        'decoder_state1_dense',
        'decoder_state2_dense',
        'decoder_state3_dense',
        'decoder_stack_layer1_dense',
        'decoder_stack_layer3_dense',
        'decoder_stack_layer5_dense',
        'decoder_stack_layer7_dense',
        'decoder_stack_layer8_dense',
        'decoder_stack_layer9_dense',
    ]

    gru_layer_names = [
        'encoder_stack_layer2_gru',
        'encoder_stack_layer4_gru',
        'encoder_stack_layer6_gru',
        'decoder_stack_layer2_gru',
        'decoder_stack_layer4_gru',
        'decoder_stack_layer6_gru',
    ]

    conv1d_layer_names = ['encoder_stack_layer9_conv']

    embedding_layer_names = ['statistical_model_embedding']

    all_layer_names = dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names

    # copy the exchanged weights into the matching submodule of the model
    for name in all_layer_names:
        target_path = exchange_name_to_name[name]
        print(f"loading weights for layer {target_path}")
        layer = model.get_submodule(target_path)
        load_torch_weights(os.path.join(args.exchange_folder, name), layer)

    checkpoint['state_dict'] = model.state_dict()

    torch.save(checkpoint, args.output)
|
||||
@@ -0,0 +1 @@
|
||||
from .fec_packets import write_fec_packets, read_fec_packets
|
||||
@@ -0,0 +1,142 @@
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <inttypes.h>
|
||||
|
||||
#include "fec_packets.h"
|
||||
|
||||
/* Reads the features of a single subframe from a FEC packet file.
 *
 * File layout as consumed here: seven int16 header fields (version,
 * header_size, num_packets, packet_size, subframe_size, subframes_per_packet,
 * num_features), followed at byte offset header_size by num_packets packets
 * of packet_size bytes each. Every packet starts with a 2-byte rate which is
 * followed by the subframes (subframe_size bytes each).
 *
 * filename        path of the FEC packet file
 * features        output buffer; must provide room for num_features floats
 *                 as announced by the file header
 * packet_index    zero-based index of the packet
 * subframe_index  zero-based index of the subframe inside the packet
 *
 * Returns 0 on success and 1 on any error (invalid argument, file cannot be
 * opened, truncated file, index out of bounds).
 */
int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index)
{

    int16_t version;
    int16_t header_size;
    int16_t num_packets;
    int16_t packet_size;
    int16_t subframe_size;
    int16_t subframes_per_packet;
    int16_t num_features;
    long offset;
    FILE *fid;

    if (filename == NULL || features == NULL)
    {
        fprintf(stderr, "get_fec_frame: invalid argument\n");
        return 1;
    }

    /* bail out early: passing a NULL stream to fread/fclose is undefined behavior */
    fid = fopen(filename, "rb");
    if (fid == NULL)
    {
        fprintf(stderr, "get_fec_frame: could not open file %s\n", filename);
        return 1;
    }

    /* read header */
    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;

    /* check if indices are valid (negative indices would seek before the packet data) */
    if (packet_index < 0 || packet_index >= num_packets
        || subframe_index < 0 || subframe_index >= subframes_per_packet)
    {
        fprintf(stderr, "get_fec_frame: index out of bounds\n");
        goto error;
    }

    /* a negative feature count can only stem from a corrupt header */
    if (num_features < 0) goto error;

    /* calculate offset in file (+ 2 is for rate); done in long to avoid int overflow */
    offset = (long) header_size + (long) packet_index * packet_size + 2 + (long) subframe_index * subframe_size;
    if (fseek(fid, offset, SEEK_SET) != 0) goto error;

    /* read features */
    if (fread(features, sizeof(*features), (size_t) num_features, fid) != (size_t) num_features) goto error;

    fclose(fid);
    return 0;

error:
    fclose(fid);
    return 1;
}
|
||||
|
||||
/* Reads the rate field of a single packet from a FEC packet file.
 *
 * The header consists of seven int16 fields (version, header_size,
 * num_packets, packet_size, subframe_size, subframes_per_packet,
 * num_features); the packets start at byte offset header_size, each spans
 * packet_size bytes and starts with its 2-byte rate.
 *
 * filename      path of the FEC packet file
 * packet_index  zero-based index of the packet
 *
 * Returns the stored rate on success and -1 on any error (invalid argument,
 * file cannot be opened, truncated file, index out of bounds).
 */
int get_fec_rate(const char * const filename, int packet_index)
{
    int16_t version;
    int16_t header_size;
    int16_t num_packets;
    int16_t packet_size;
    int16_t subframe_size;
    int16_t subframes_per_packet;
    int16_t num_features;
    long offset;
    int16_t rate;
    FILE *fid;

    if (filename == NULL)
    {
        fprintf(stderr, "get_fec_rate: invalid argument\n");
        return -1;
    }

    /* bail out early: passing a NULL stream to fread/fclose is undefined behavior */
    fid = fopen(filename, "rb");
    if (fid == NULL)
    {
        fprintf(stderr, "get_fec_rate: could not open file %s\n", filename);
        return -1;
    }

    /* read header */
    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;

    /* check if index is valid (a negative index would seek before the packet data) */
    if (packet_index < 0 || packet_index >= num_packets)
    {
        fprintf(stderr, "get_fec_rate: index out of bounds\n");
        goto error;
    }

    /* calculate offset of the packet start, which is where the rate is stored;
       done in long to avoid int overflow */
    offset = (long) header_size + (long) packet_index * packet_size;
    if (fseek(fid, offset, SEEK_SET) != 0) goto error;

    /* read rate */
    if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error;

    fclose(fid);
    return (int) rate;

error:
    fclose(fid);
    return -1;
}
|
||||
|
||||
#if 0
|
||||
int main()
|
||||
{
|
||||
float features[20];
|
||||
int i;
|
||||
|
||||
if (get_fec_frame("../test.fec", &features[0], 0, 127))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (i = 0; i < 20; i ++)
|
||||
{
|
||||
printf("%d %f\n", i, features[i]);
|
||||
}
|
||||
|
||||
printf("rate: %d\n", get_fec_rate("../test.fec", 0));
|
||||
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,34 @@
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#ifndef _FEC_PACKETS_H
#define _FEC_PACKETS_H

/* Reads the features of subframe subframe_index of packet packet_index from
 * the FEC packet file filename into features. The number of floats written is
 * taken from the file header, so the buffer must be large enough for it.
 * Returns 0 on success and 1 on failure. */
int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index);

/* Reads the rate stored for packet packet_index of the FEC packet file
 * filename. Returns the rate on success and -1 on failure. */
int get_fec_rate(const char * const filename, int packet_index);

#endif
|
||||
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def write_fec_packets(filename, packets, rates=None):
    """ writes packets in binary format

    The file starts with a header of seven int16 values (version, header_size,
    num_packets, packet_size, subframe_size, subframes_per_packet,
    num_features). Each packet is stored as an int16 rate followed by its
    subframes as float32 values. Subframes are written in reversed order
    (flip along axis -2). All values use the native byte order.

    Parameters:
        filename: path of the file to write
        packets: non-empty sequence of arrays whose last two dimensions are
            (subframes_per_packet, num_features); the sizes are derived from
            the first packet
        rates: optional sequence holding one rate per packet; a rate of 0 is
            stored for every packet if omitted

    Raises:
        ValueError: if packets is empty, if rates has fewer entries than
            packets or if a header field does not fit into an int16
    """

    num_packets = len(packets)
    if num_packets == 0:
        raise ValueError("write_fec_packets: packets must not be empty")

    # fail before the file is created rather than leaving a truncated file behind
    if rates is not None and len(rates) < num_packets:
        raise ValueError(
            f"write_fec_packets: got {len(rates)} rates for {num_packets} packets"
        )

    # derive some sizes
    subframes_per_packet = packets[0].shape[-2]
    num_features = packets[0].shape[-1]

    # size of float is 4
    subframe_size = num_features * 4
    packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate

    version = 1
    # header size (version, header_size, num_packets, packet_size, subframe_size, subframes_per_packet, num_features)
    header_size = 14

    header = (
        version,
        header_size,
        num_packets,
        packet_size,
        subframe_size,
        subframes_per_packet,
        num_features,
    )

    # all header fields are stored as int16; values above the int16 range
    # would either wrap silently or raise deep inside numpy
    int16_max = np.iinfo(np.int16).max
    if any(value > int16_max for value in header):
        raise ValueError(
            f"write_fec_packets: header field exceeds int16 range: {header}"
        )

    with open(filename, 'wb') as f:

        # header
        f.write(np.array(header, dtype=np.int16).tobytes())

        # packets
        for i, packet in enumerate(packets):
            rate = 0 if rates is None else rates[i]
            f.write(np.int16(rate).tobytes())

            # subframes are stored in reversed order
            features = np.flip(packet, axis=-2)
            f.write(features.astype(np.float32).tobytes())
|
||||
|
||||
|
||||
def read_fec_packets(filename, return_rates=False):
    """ reads packets from binary format

    Expects a header of seven int16 values (version, header_size,
    num_packets, packet_size, subframe_size, subframes_per_packet,
    num_features) followed, at byte offset header_size, by the packets. Each
    packet consists of an int16 rate and subframes_per_packet subframes of
    subframe_size bytes holding float32 values. Subframes are flipped along
    axis -2 after reading.

    Parameters:
        filename: path of the file to read
        return_rates: if True, the per-packet rates are returned as well

    Returns:
        list of arrays of shape (subframes_per_packet, num_features), or the
        tuple (packets, rates) if return_rates is True

    Raises:
        ValueError: if the file ends before all announced data has been read
    """

    with open(filename, 'rb') as f:

        # header
        raw_header = f.read(14)
        if len(raw_header) != 14:
            raise ValueError("read_fec_packets: truncated header")

        (version,
         header_size,
         num_packets,
         packet_size,
         subframe_size,
         subframes_per_packet,
         num_features) = np.frombuffer(raw_header, dtype=np.int16).tolist()

        # the packet data starts at header_size, which is not necessarily
        # the end of the seven fields read above
        f.seek(header_size)

        payload_size = subframe_size * subframes_per_packet
        packet_shape = (subframes_per_packet, num_features)

        # packets
        rates = []
        packets = []
        for _ in range(num_packets):

            raw_rate = f.read(2)
            raw_features = f.read(payload_size)
            if len(raw_rate) != 2 or len(raw_features) != payload_size:
                raise ValueError("read_fec_packets: truncated packet data")

            # .item() must be called; storing the bound method would be useless
            rates.append(np.frombuffer(raw_rate, dtype=np.int16).item())

            features = np.reshape(np.frombuffer(raw_features, dtype=np.float32), packet_shape)
            packets.append(np.flip(features, axis=-2))

    if return_rates:
        return packets, rates

    return packets
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/bin/sh

# Converts every .wav file found below datasets/ to 16 kHz mono raw PCM with
# sox and concatenates the results into all_speech.sw next to the datasets
# directory.

# abort if the directory is missing; otherwise find would scan the current
# directory and the output would be written to its parent
cd datasets || exit 1

#parallel -j +2 'unzip -n {}' ::: *.zip

# -k keeps the output in the order of the input files
find . -name "*.wav" | parallel -k -j 20 'sox --no-dither {} -t sw -r 16000 -c 1 -' > ../all_speech.sw
|
||||
@@ -0,0 +1,2 @@
|
||||
from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
|
||||
from .dataset import RDOVAEDataset
|
||||
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class RDOVAEDataset(torch.utils.data.Dataset):
    """ dataset of fixed-length feature sequences paired with a random quantization level

    The feature file is read as a flat float32 array and interpreted as frames
    of num_features values, of which the first num_used_features are kept.
    Every item consists of sequence_length consecutive frames, a rate lambda
    and a quantization index; the latter two are constant within an item and
    repeated once per encoder step (sequence_length // enc_stride).
    """

    def __init__(self,
                 feature_file,
                 sequence_length,
                 num_used_features=20,
                 num_features=36,
                 lambda_min=0.0002,
                 lambda_max=0.0135,
                 quant_levels=16,
                 enc_stride=2):

        self.sequence_length = sequence_length
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max
        self.enc_stride = enc_stride
        self.quant_levels = quant_levels
        # chosen such that level 0 maps to lambda_min and the highest level to lambda_max
        self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min)

        if sequence_length % enc_stride != 0:
            raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}")

        raw = np.fromfile(feature_file, dtype=np.float32)
        frames = np.reshape(raw, (-1, num_features))
        self.features = frames[:, :num_used_features]
        self.num_sequences = self.features.shape[0] // sequence_length

    def __len__(self):
        """ number of complete sequences in the feature file """
        return self.num_sequences

    def __getitem__(self, index):
        """ returns (features, rate_lambda, q_ids) for the sequence at position index """
        first_frame = index * self.sequence_length
        last_frame = (index + 1) * self.sequence_length
        features = self.features[first_frame: last_frame, :]

        # one random level per item, repeated for every encoder step
        num_steps = self.sequence_length // self.enc_stride
        level = np.random.randint(0, self.quant_levels, (1)).astype(np.int64)
        q_ids = np.repeat(level, num_steps, axis=0)

        # exponential interpolation between lambda_min and lambda_max
        exponent = q_ids.astype(np.float32) / self.denominator
        rate_lambda = self.lambda_min * np.exp(exponent).astype(np.float32)

        return features, rate_lambda, q_ids
|
||||
@@ -0,0 +1,752 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
""" Pytorch implementations of rate distortion optimized variational autoencoder """
|
||||
|
||||
import math as m
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
import os
|
||||
source_dir = os.path.split(os.path.abspath(__file__))[0]
|
||||
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
|
||||
from utils.sparsification import GRUSparsifier
|
||||
from torch.nn.utils import weight_norm
|
||||
sys.path.append(os.path.join(source_dir, "../../dnntools"))
|
||||
from dnntools.quantization import soft_quant
|
||||
|
||||
# Quantization and rate related utility functions
|
||||
|
||||
def soft_pvq(x, k):
|
||||
""" soft pyramid vector quantizer """
|
||||
|
||||
# L2 normalization
|
||||
x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
# quantization loop, no need to track gradients here
|
||||
x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
|
||||
|
||||
# set initial scaling factor to k
|
||||
scale_factor = k
|
||||
x_scaled = scale_factor * x_norm1
|
||||
x_quant = torch.round(x_scaled)
|
||||
|
||||
# we aim for ||x_quant||_L1 = k
|
||||
for _ in range(10):
|
||||
# remove signs and calculate L1 norm
|
||||
abs_x_quant = torch.abs(x_quant)
|
||||
abs_x_scaled = torch.abs(x_scaled)
|
||||
l1_x_quant = torch.sum(abs_x_quant, axis=-1)
|
||||
|
||||
# increase, where target is too small and decrease, where target is too large
|
||||
plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
|
||||
minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
|
||||
factor = torch.where(l1_x_quant > k, minus, plus)
|
||||
factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
|
||||
scale_factor = scale_factor * factor.unsqueeze(-1)
|
||||
|
||||
# update x
|
||||
x_scaled = scale_factor * x_norm1
|
||||
x_quant = torch.round(x_quant)
|
||||
|
||||
# L2 normalization of quantized x
|
||||
x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
|
||||
quantization_error = x_quant_norm2 - x_norm2
|
||||
|
||||
return x_norm2 + quantization_error.detach()
|
||||
|
||||
def cache_parameters(func):
    """ decorator that memoizes func on its (hashable) positional arguments """
    memo = {}

    def cached_func(*args):
        # evaluate func only on the first call with a given argument tuple
        if args not in memo:
            memo[args] = func(*args)
        return memo[args]

    return cached_func
|
||||
|
||||
@cache_parameters
def pvq_codebook_size(n, k):
    """ size of the PVQ codebook for dimension n and k pulses

    Evaluated with the recurrence V(n, k) = V(n-1, k) + V(n, k-1) + V(n-1, k-1)
    with V(n, 0) = 1 and V(0, k) = 0 for k > 0; results are memoized.
    """

    # base cases; the k == 0 check takes precedence over n == 0
    if k == 0:
        return 1
    if n == 0:
        return 0

    one_dim_less = pvq_codebook_size(n - 1, k)
    one_pulse_less = pvq_codebook_size(n, k - 1)
    both_less = pvq_codebook_size(n - 1, k - 1)

    return one_dim_less + one_pulse_less + both_less
|
||||
|
||||
|
||||
def soft_rate_estimate(z, r, reduce=True):
    """ rate approximation with dependent theta Eq. (7)

    Sums -log2((1 - r) / (1 + r) * r^|z|) over the last dimension (a small
    constant is added inside the logarithm for numerical stability) and
    averages over all remaining dimensions unless reduce is False.
    """

    magnitude = torch.abs(z)
    probability = (1 - r)/(1 + r) * r ** magnitude + 1e-6
    rate = torch.sum(- torch.log2(probability), dim=-1)

    return torch.mean(rate) if reduce else rate
|
||||
|
||||
|
||||
def hard_rate_estimate(z, r, theta, reduce=True):
    """ hard rate approximation

    Rounds z and sums the negative log2 probabilities over the last
    dimension, using separate expressions for zero and non-zero symbols. The
    result is averaged over the remaining dimensions unless reduce is False.
    """

    magnitude = torch.abs(torch.round(z))

    # probability of the zero symbol
    p0 = 1 - r ** (0.5 + 0.5 * theta)

    # alpha is 1 where the rounded value is 0 and 0 otherwise
    alpha = torch.relu(1 - magnitude) ** 2

    log_p_zero = torch.log2(p0 * r ** magnitude + 1e-6)
    log_p_nonzero = torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (magnitude - 1) + 1e-6)

    rate = - torch.sum(
        (alpha * log_p_zero
         + (1 - alpha) * log_p_nonzero),
        dim=-1
    )

    if not reduce:
        return rate

    return torch.mean(rate)
|
||||
|
||||
|
||||
|
||||
def soft_dead_zone(x, dead_zone):
    """ smooth approximation of a dead zone: subtracts a tanh-shaped offset scaled by dead_zone from x """
    width = dead_zone * 0.05
    offset = width * torch.tanh(x / (0.1 + width))
    return x - offset
|
||||
|
||||
|
||||
def hard_quantize(x):
    """ rounds x in the forward pass while passing gradients through unchanged """
    rounding_residual = torch.round(x) - x
    return x + rounding_residual.detach()
|
||||
|
||||
|
||||
def noise_quantize(x):
    """ mimics quantization by adding uniform noise from [-0.5, 0.5) to x """
    noise = torch.rand_like(x) - 0.5
    return x + noise
|
||||
|
||||
|
||||
# loss functions
|
||||
|
||||
|
||||
def distortion_loss(y_true, y_pred, rate_lambda=None):
    """ custom distortion loss for LPCNet features

    The last dimension must hold 20 features: indices 0..17 are compared by
    squared error, index 18 by a weighted absolute error and index 19 by a
    down-weighted squared error.

    Parameters:
        y_true: reference features, last dimension of size 20
        y_pred: predicted features of the same shape
        rate_lambda: optional tensor; if given, the per-position loss is
            divided by sqrt(rate_lambda) before averaging

    Returns:
        scalar loss tensor

    Raises:
        ValueError: if the last dimension of y_true is not 20
    """

    if y_true.size(-1) != 20:
        raise ValueError('distortion loss is designed to work with 20 features')

    ceps_error = y_pred[..., :18] - y_true[..., :18]
    pitch_error = 2*(y_pred[..., 18:19] - y_true[..., 18:19])
    corr_error = y_pred[..., 19:] - y_true[..., 19:]
    # the pitch error is weighted by the reference value of feature 19
    pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2

    # the pitch and correlation terms have a trailing dimension of size 1 and
    # are broadcast over the 18 cepstral terms before averaging
    loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)

    if rate_lambda is not None:
        loss = loss / torch.sqrt(rate_lambda)

    loss = torch.mean(loss)

    return loss
|
||||
|
||||
|
||||
# sampling functions
|
||||
|
||||
import random
|
||||
|
||||
|
||||
def random_split(start, stop, num_splits=3, min_len=3):
    """ randomly splits the interval [start, stop] into num_splits + 1 segments

    Split points are drawn uniformly from start .. stop - 1 until every
    segment is at least min_len long (rejection sampling).

    Parameters:
        start: first boundary
        stop: last boundary
        num_splits: number of inner split points
        min_len: minimal distance between neighboring boundaries

    Returns:
        sorted list [start, s_1, ..., s_num_splits, stop]

    Raises:
        ValueError: if the interval is too short to hold num_splits + 1
            segments of length min_len (sampling would never terminate)
    """

    # without this check the rejection loop below would spin forever
    if min_len > 0 and stop - start < (num_splits + 1) * min_len:
        raise ValueError(
            f"random_split: cannot split [{start}, {stop}] into {num_splits + 1} "
            f"segments of minimal length {min_len}"
        )

    def draw_candidate():
        inner = sorted([random.randint(start, stop-1) for _ in range(num_splits)])
        return [start] + inner + [stop]

    def shortest_segment(boundaries):
        return min(right - left for left, right in zip(boundaries, boundaries[1:]))

    candidate = draw_candidate()
    while shortest_segment(candidate) < min_len:
        candidate = draw_candidate()

    return candidate
|
||||
|
||||
|
||||
|
||||
# weight initialization and clipping
|
||||
def init_weights(module):
    """ orthogonally initializes the hidden-to-hidden weights of GRU modules; other modules are left untouched """

    if not isinstance(module, nn.GRU):
        return

    for name, parameter in module.named_parameters():
        if name.startswith('weight_hh_'):
            nn.init.orthogonal_(parameter)
|
||||
|
||||
|
||||
def weight_clip_factory(max_value):
    """ creates a weight clipping function that bounds the sum of absolute values of adjacent weights

    The returned function is meant to be applied to modules; for GRU and
    Linear modules it rescales, in place, every pair of neighboring columns
    (0 and 1, 2 and 3, ...) of each weight parameter such that the sum of
    their absolute values does not exceed max_value.
    """

    def clip_weight_(w):
        # only complete column pairs are processed: a trailing odd column is left as is
        num_cols = w.size(1)
        stop = num_cols - 1 if num_cols % 2 else num_cols

        pair_sums = torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2])
        # both columns of a pair share the same scaling factor
        per_column_sums = torch.repeat_interleave(pair_sums, 2, 1)
        max_values = max_value * torch.ones_like(w[:, :stop])
        # the factor is 1 wherever the pair sum is below max_value
        factor = max_value / torch.maximum(max_values, per_column_sums)

        with torch.no_grad():
            w[:, :stop] *= factor

    def clip_weights(module):
        if not isinstance(module, (nn.GRU, nn.Linear)):
            return
        for name, w in module.named_parameters():
            if name.startswith('weight'):
                clip_weight_(w)

    return clip_weights
|
||||
|
||||
def n(x):
    """ adds uniform noise from [-0.5/127, 0.5/127) to x and clamps the result to [-1, 1] """
    noise = (1./127.)*(torch.rand_like(x)-.5)
    return torch.clamp(x + noise, min=-1., max=1.)
|
||||
|
||||
# RDOVAE module and submodules
|
||||
|
||||
# Sparsification schedule; these values are handed to GRUSparsifier as the
# arguments following the task list.
# NOTE(review): presumably start/stop are given in training steps and interval
# is the update period - confirm against utils.sparsification.GRUSparsifier.
sparsify_start = 12000
sparsify_stop = 24000
sparsify_interval = 100
sparsify_exponent = 3
#sparsify_start = 0
#sparsify_stop = 0

# Per-matrix sparsification settings for GRU layers, keyed by weight matrix name.
# NOTE(review): each tuple presumably holds (target density, block shape,
# keep-diagonal flag) - confirm against utils.sparsification.GRUSparsifier.
# Only the input weights (W_i*) are listed; the recurrent entries are disabled.
sparse_params1 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.6, [8, 4], False),
    'W_iz' : (0.4, [8, 4], False),
    'W_in' : (0.8, [8, 4], False)
}

# Second parameter set with half the values of sparse_params1 in the first tuple entry.
sparse_params2 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.3, [8, 4], False),
    'W_iz' : (0.2, [8, 4], False),
    'W_in' : (0.4, [8, 4], False)
}
|
||||
|
||||
|
||||
class MyConv(nn.Module):
    """ causal two-tap 1d convolution with tanh activation operating on (batch, time, channels) input

    The input is left-padded with `dilation` zero frames, so the output has
    the same number of time steps as the input.
    """

    def __init__(self, input_dim, output_dim, dilation=1, softquant=False):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation

        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
        if softquant:
            self.conv = soft_quant(self.conv)

    def forward(self, x, state=None):
        """ applies the convolution; the state argument is accepted but not used """
        # zero frames prepended along the time axis make the convolution causal
        left_padding = torch.zeros_like(x[:, 0:self.dilation, :], device=x.device)
        padded = torch.cat([left_padding, x], -2)

        # Conv1d expects (batch, channels, time)
        channels_first = padded.permute(0, 2, 1)
        activated = torch.tanh(self.conv(channels_first))

        return activated.permute(0, 2, 1)
|
||||
|
||||
class GLU(nn.Module):
    """ gated linear unit: scales the input element-wise by a sigmoid gate computed from the input itself """

    def __init__(self, feat_size, softquant=False):
        super().__init__()

        # NOTE: this re-seeds the global torch RNG on every instantiation
        torch.manual_seed(5)

        gate_layer = nn.Linear(feat_size, feat_size, bias=False)
        self.gate = weight_norm(gate_layer)

        if softquant:
            self.gate = soft_quant(self.gate)

        self.init_weights()

    def init_weights(self):
        """ applies orthogonal initialization to the weights of all conv, linear and embedding submodules """
        initialized_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)

        for layer in self.modules():
            if isinstance(layer, initialized_types):
                nn.init.orthogonal_(layer.weight.data)

    def forward(self, x):
        gate_values = torch.sigmoid(self.gate(x))

        return x * gate_values
|
||||
|
||||
class CoreEncoder(nn.Module):
    """ encoder network built from alternating GRU and causal convolution layers with dense connectivity

    The output of every layer is concatenated to its input, so the feature
    width grows from 64 (after the first dense layer) to 864 before the final
    projections.
    """

    # width of the hidden layer of the state projection
    STATE_HIDDEN = 128
    # number of consecutive input frames that are merged into one encoder step
    FRAMES_PER_STEP = 2
    # not referenced inside this class
    CONV_KERNEL_SIZE = 4

    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
        """ core encoder for RDOVAE

            Computes latents and initial states from features

            Parameters:
                feature_dim: number of features per input frame
                output_dim: dimension of the latent vector z
                cond_size, cond_size2: stored as attributes, not used by the layers defined here
                state_size: dimension of the state output
                softquant: if True, soft quantization is applied to the GRU and dense
                    output layers (the convolution layers always enable it)
        """

        super(CoreEncoder, self).__init__()

        # hyper parameters
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        # derived parameters
        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim

        # layers (each input width is the sum of the widths of all preceding outputs:
        # 64 + 64 = 128, + 96 = 224, + 64 = 288, ... , + 96 = 864)
        self.dense_1 = nn.Linear(self.input_dim, 64)
        self.gru1 = nn.GRU(64, 64, batch_first=True)
        self.conv1 = MyConv(128, 96, softquant=True)
        self.gru2 = nn.GRU(224, 64, batch_first=True)
        self.conv2 = MyConv(288, 96, dilation=2, softquant=True)
        self.gru3 = nn.GRU(384, 64, batch_first=True)
        self.conv3 = MyConv(448, 96, dilation=2, softquant=True)
        self.gru4 = nn.GRU(544, 64, batch_first=True)
        self.conv4 = MyConv(608, 96, dilation=2, softquant=True)
        self.gru5 = nn.GRU(704, 64, batch_first=True)
        self.conv5 = MyConv(768, 96, dilation=2, softquant=True)

        # projection to the latent vector
        self.z_dense = nn.Linear(864, self.output_dim)


        # two-layer projection to the state
        self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)

        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"encoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)

        if softquant:
            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
            self.z_dense = soft_quant(self.z_dense)
            self.state_dense_1 = soft_quant(self.state_dense_1)
            self.state_dense_2 = soft_quant(self.state_dense_2)


    def forward(self, features):
        """ computes latents and states from a 3-dimensional feature tensor

            The second dimension of features is reduced by FRAMES_PER_STEP, i.e. it
            has to be a multiple of FRAMES_PER_STEP.

            Returns the tuple (z, states) with one entry per encoder step.
        """

        # reshape features: FRAMES_PER_STEP consecutive frames are stacked along the last dimension
        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))

        # batch and device are not used below
        batch = x.size(0)
        device = x.device

        # run encoding layer stack; n() adds noise to and clamps every layer output
        # before it is concatenated to the running feature tensor
        x = n(torch.tanh(self.dense_1(x)))
        x = torch.cat([x, n(self.gru1(x)[0])], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.gru2(x)[0])], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.gru3(x)[0])], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.gru4(x)[0])], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.gru5(x)[0])], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)
        z = self.z_dense(x)

        # init state for decoder
        states = torch.tanh(self.state_dense_1(x))
        states = self.state_dense_2(states)

        return z, states
|
||||
|
||||
|
||||
|
||||
|
||||
class CoreDecoder(nn.Module):
    """Core decoder for RDOVAE.

    Turns a sequence of latent vectors plus one initial state vector into
    feature frames; every decoder step emits FRAMES_PER_STEP frames.
    """

    FRAMES_PER_STEP = 4

    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
        """Build the decoder layer stack.

        Args:
            input_dim: width of the latent vectors fed to the first layer.
            output_dim: width of one output feature frame.
            cond_size: stored on the instance; not otherwise used here.
            cond_size2: stored on the instance; not otherwise used here.
            state_size: width of the initial state vector.
            softquant: when true, the GRU weights, the output layer and the
                state projection are wrapped with ``soft_quant``; the flag
                is also forwarded to the conv and GLU layers.
        """

        super().__init__()

        # hyper parameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        self.input_size = self.input_dim

        # layers: every GRU appends 96 channels and every conv appends 32,
        # so the running width goes 96, 192, 224, 320, ... and ends at 736.
        # Modules are created in the order gru1, conv1, gru2, conv2, ...
        self.dense_1 = nn.Linear(self.input_size, 96)
        width = 96
        for idx in range(1, 6):
            setattr(self, f"gru{idx}", nn.GRU(width, 96, batch_first=True))
            width += 96
            setattr(self, f"conv{idx}", MyConv(width, 32, softquant=softquant))
            width += 32
        self.output = nn.Linear(width, self.FRAMES_PER_STEP * self.output_dim)
        for idx in range(1, 6):
            setattr(self, f"glu{idx}", GLU(96, softquant=softquant))

        # maps the transmitted state to the five 96-wide GRU initial states
        self.hidden_init = nn.Linear(self.state_size, 128)
        self.gru_init = nn.Linear(128, 480)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"decoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)

        # one sparsifier per GRU; the last two GRUs use the second parameter set
        sparsify_plan = (
            (self.gru1, sparse_params1),
            (self.gru2, sparse_params1),
            (self.gru3, sparse_params1),
            (self.gru4, sparse_params2),
            (self.gru5, sparse_params2),
        )
        self.sparsifier = [
            GRUSparsifier([(gru, params)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)
            for gru, params in sparsify_plan
        ]

        if softquant:
            for name in ('gru1', 'gru2', 'gru3', 'gru4', 'gru5'):
                setattr(self, name, soft_quant(getattr(self, name), names=['weight_hh_l0', 'weight_ih_l0']))
            self.output = soft_quant(self.output)
            self.gru_init = soft_quant(self.gru_init)

    def sparsify(self):
        """Advance every GRU sparsifier by one step."""
        for gru_sparsifier in self.sparsifier:
            gru_sparsifier.step()

    def forward(self, z, initial_state):
        """Decode latents into feature frames.

        Args:
            z: latent sequence; its last dimension feeds ``dense_1``.
            initial_state: state tensor whose last dimension feeds
                ``hidden_init``; after projection, its first two
                dimensions are swapped to form the GRU initial states.

        Returns:
            Tensor with FRAMES_PER_STEP times as many entries along dim 1
            as ``z`` and ``1 / FRAMES_PER_STEP`` of the output layer width
            along dim 2.
        """

        hidden = torch.tanh(self.hidden_init(initial_state))
        gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))

        # cut the 480-wide projection into one 96-wide initial state per GRU
        bands = (slice(None, 96), slice(96, 192), slice(192, 288), slice(288, 384), slice(384, None))
        initial_states = [gru_state[:, :, band].contiguous() for band in bands]

        stages = (
            (self.gru1, self.glu1, self.conv1),
            (self.gru2, self.glu2, self.conv2),
            (self.gru3, self.glu3, self.conv3),
            (self.gru4, self.glu4, self.conv4),
            (self.gru5, self.glu5, self.conv5),
        )

        # run decoding layer stack; every layer output is appended to x
        x = n(torch.tanh(self.dense_1(z)))
        for (gru, glu, conv), h0 in zip(stages, initial_states):
            x = torch.cat([x, n(glu(n(gru(x, h0)[0])))], -1)
            x = torch.cat([x, n(conv(x))], -1)

        # output layer and reshaping: unpack FRAMES_PER_STEP frames per step
        x10 = self.output(x)
        frame_shape = (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP)

        return torch.reshape(x10, frame_shape)
|
||||
|
||||
|
||||
class StatisticalModel(nn.Module):
    """Statistical model for the latent space.

    One embedding row per quantization level holds six parameter groups,
    each ``latent_dim + state_dim`` wide: scaling, dead zone, theta_soft,
    r_soft, theta_hard and r_hard (in that storage order).
    """

    def __init__(self, quant_levels, latent_dim, state_dim):
        """Create the embedding table and zero it.

        Args:
            quant_levels: number of rows in the embedding table.
            latent_dim: width of the latent part of each parameter group.
            state_dim: width of the state part of each parameter group.
        """

        super().__init__()

        # copy parameters
        self.latent_dim = latent_dim
        self.state_dim = state_dim
        self.total_dim = latent_dim + state_dim
        self.quant_levels = quant_levels
        self.embedding_dim = 6 * self.total_dim

        # quantization embedding, initialized to 0
        self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
        nn.init.zeros_(self.quant_embedding.weight)

    def forward(self, quant_ids):
        """Look up quant_ids and return the statistical model parameters.

        Returns:
            Dict with the raw embedding under 'quant_embedding' and the six
            activated parameter groups; scaling and dead zone go through
            softplus, the r/theta groups through a sigmoid.
        """

        embedded = self.quant_embedding(quant_ids)
        width = self.total_dim

        def group(index):
            # index-th block of `width` channels along the last dimension
            return embedded[..., index * width : (index + 1) * width]

        # CAVE: theta_soft is not used anymore. Kick it out?
        return {
            'quant_embedding' : embedded,
            'quant_scale' : F.softplus(group(0)),
            'dead_zone' : F.softplus(group(1)),
            'r_hard' : torch.sigmoid(group(5)),
            'theta_hard' : torch.sigmoid(group(4)),
            'r_soft' : torch.sigmoid(group(3)),
            'theta_soft' : torch.sigmoid(group(2)),
        }
|
||||
|
||||
|
||||
class RDOVAE(nn.Module):
    """Rate-distortion-optimized VAE: encoder, decoder and statistical model.

    The encoder and decoder are wrapped in ``nn.DataParallel``; reach the
    underlying modules through ``.module``.
    """

    def __init__(self,
                 feature_dim,
                 latent_dim,
                 quant_levels,
                 cond_size,
                 cond_size2,
                 state_dim=24,
                 split_mode='split',
                 chunks_per_offset=4,
                 clip_weights=False,
                 pvq_num_pulses=82,
                 state_dropout_rate=0,
                 softquant=False):
        """Build the sub-models.

        Args:
            feature_dim: width of one feature frame (encoder input and
                decoder output).
            latent_dim: width of the latent vectors.
            quant_levels: number of quantization levels of the statistical
                model.
            cond_size: forwarded to encoder and decoder.
            cond_size2: forwarded to encoder and decoder.
            state_dim: width of the decoder initial state.
            split_mode: default chunking mode used by ``forward``.
            chunks_per_offset: default number of chunks per offset used by
                ``forward``.
            clip_weights: when true, ``clip_weights()`` applies a clipping
                function built by ``weight_clip_factory(0.496)``.
            pvq_num_pulses: pulse count used by ``encode`` for the states.
            state_dropout_rate: per-batch-entry probability of zeroing the
                hard-quantized states in ``forward``.
            softquant: forwarded to encoder and decoder.

        Raises:
            ValueError: if the decoder stride is not a multiple of the
                encoder stride.
        """

        super().__init__()

        self.feature_dim = feature_dim
        self.latent_dim = latent_dim
        self.quant_levels = quant_levels
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.split_mode = split_mode
        self.chunks_per_offset = chunks_per_offset
        self.state_dim = state_dim
        self.pvq_num_pulses = pvq_num_pulses
        self.state_dropout_rate = state_dropout_rate

        # submodules encoder and decoder share the statistical model
        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))

        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
        self.dec_stride = CoreDecoder.FRAMES_PER_STEP

        if clip_weights:
            self.weight_clip_fn = weight_clip_factory(0.496)
        else:
            self.weight_clip_fn = None

        if self.dec_stride % self.enc_stride != 0:
            raise ValueError("get_decoder_chunks_generic: encoder stride does not divide decoder stride")

    def clip_weights(self):
        """Apply the weight clipping function to all sub-modules, if enabled."""
        if self.weight_clip_fn is not None:
            self.apply(self.weight_clip_fn)

    def sparsify(self):
        """Run one sparsification step on the decoder (the encoder is not sparsified)."""
        self.core_decoder.module.sparsify()

    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
        """Partition the encoder output into decoder input chunks.

        The decoder consumes every ``stride``-th latent, where
        ``stride = dec_stride // enc_stride``. For each of the ``stride``
        possible offsets, the usable latent range is cut into
        ``chunks_per_offset`` pieces.

        Args:
            z_frames: number of encoder steps available.
            mode: 'split' (even pieces), 'random_split' (split points drawn
                by ``random_split``) or 'skewed_split' (split points packed
                into the first quarter of the range).
            chunks_per_offset: number of chunks generated per offset.

        Returns:
            List of dicts with keys 'z_start', 'z_stop', 'z_stride' (a
            slice over encoder steps) and 'features_start',
            'features_stop' (the matching feature frame range).

        Raises:
            ValueError: if ``z_frames`` is too small to hold a valid start
                index, or if ``mode`` is unknown.
        """

        enc_stride = self.enc_stride
        dec_stride = self.dec_stride

        stride = dec_stride // enc_stride

        chunks = []

        for offset in range(stride):
            # start is the smallest number = offset mod stride that decodes to a valid range
            start = offset
            while enc_stride * (start + 1) - dec_stride < 0:
                start += stride

            # check if start is a valid index
            if start >= z_frames:
                raise ValueError("get_decoder_chunks_generic: range too small")

            # stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride
            stop = z_frames - (z_frames % stride) + offset
            while stop < z_frames:
                stop += stride

            # calculate split points
            length = (stop - start)
            if mode == 'split':
                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            elif mode == 'random_split':
                split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
            elif mode == 'skewed_split':
                split_points = [start + stride * int(i * length / 4 / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            else:
                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")

            for i in range(chunks_per_offset):
                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
                # provided that i - j = 1 mod stride
                chunks.append({
                    'z_start'         : split_points[i],
                    'z_stop'          : split_points[i + 1] - stride + 1,
                    'z_stride'        : stride,
                    'features_start'  : enc_stride * (split_points[i] + 1) - dec_stride,
                    'features_stop'   : enc_stride * (split_points[i + 1] - stride + 1)
                })

        return chunks

    def forward(self, features, q_id):
        """Training forward pass: encode, quantize and decode chunk-wise.

        Args:
            features: input features passed to the encoder.
            q_id: quantization indices passed to the statistical model; the
                result is indexed with three axes below.

        Returns:
            Dict with
            'outputs_hard_quant' / 'outputs_soft_quant': lists of
                ``(decoded_features, features_start, features_stop)``, one
                entry per chunk, decoded from rounded respectively
                noise-quantized latents and states;
            'z' / 'states': the scaled and dead-zoned (not yet quantized)
                latents and states;
            'statistical_model': the statistical model output for ``q_id``.
        """

        # calculate statistical model from quantization ID
        statistical_model = self.statistical_model(q_id)

        # run encoder
        z, states = self.core_encoder(features)

        # scaling and dead-zone; the first latent_dim channels of each
        # parameter group belong to the latents, the rest to the states
        z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
        z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])

        # quantization (hard and noise variant), followed by un-scaling
        z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
        z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
        states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
        states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])

        states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
        states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]

        if self.state_dropout_rate > 0:
            # zero the hard-quantized states of randomly chosen batch entries
            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
            mask = torch.ones_like(states_q)
            mask[drop] = 0
            states_q = states_q * mask

        # decoder
        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode, chunks_per_offset=self.chunks_per_offset)

        outputs_hq = []
        outputs_sq = []
        for chunk in chunks:
            # decoder with hard quantized input; the latent sequence is
            # reversed and the state at the end of the chunk seeds the decoder
            z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

            # decoder with soft quantized input
            z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

        return {
            'outputs_hard_quant' : outputs_hq,
            'outputs_soft_quant' : outputs_sq,
            'z'                  : z,
            'states'             : states,
            'statistical_model'  : statistical_model
        }

    def encode(self, features):
        """ encoder with quantization and rate estimation

        Returns:
            Tuple ``(z, states, state_size)``: the unquantized latents, the
            PVQ-quantized initial states and log2 of the PVQ codebook size.
        """

        z, states = self.core_encoder(features)

        # quantization of initial states
        states = soft_pvq(states, self.pvq_num_pulses)
        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))

        return z, states, state_size

    def decode(self, z, initial_state):
        """ decoder (flips sequences by itself) """

        z_reverse = torch.flip(z, [1])
        features_reverse = self.core_decoder(z_reverse, initial_state)
        features = torch.flip(features_reverse, [1])

        return features

    def quantize(self, z, q_ids):
        """ quantization of latent vectors

        Returns:
            Tuple ``(zq, sizes)``: the rounded, scaled latents and the
            unreduced hard rate estimate for them.
        """

        stats = self.statistical_model(q_ids)

        # The parameter groups are latent_dim + state_dim wide, so the latent
        # part has to be selected on the LAST axis (as in unquantize/forward);
        # slicing the first axis would truncate the batch instead.
        zq = z * stats['quant_scale'][:,:,:self.latent_dim]
        zq = soft_dead_zone(zq, stats['dead_zone'][:,:,:self.latent_dim])
        zq = torch.round(zq)

        sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)

        return zq, sizes

    def unquantize(self, zq, q_ids):
        """ re-scaling of latent vector """

        stats = self.statistical_model(q_ids)

        z = zq / stats['quant_scale'][:,:,:self.latent_dim]

        return z

    def freeze_model(self):
        """Freeze every parameter except those of the statistical model."""

        # freeze all parameters
        for p in self.parameters():
            p.requires_grad = False

        # ... then re-enable training of the statistical model only
        for p in self.statistical_model.parameters():
            p.requires_grad = True
|
||||
@@ -0,0 +1,4 @@
|
||||
numpy
|
||||
scipy
|
||||
torch
|
||||
tqdm
|
||||
290
managed_components/78__esp-opus/dnn/torch/rdovae/train_rdovae.py
Normal file
290
managed_components/78__esp-opus/dnn/torch/rdovae/train_rdovae.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
/* Copyright (c) 2022 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
|
||||
|
||||
|
||||
def _add_arguments(target, specs):
    """Register every (flag, options) pair of specs on target, in order."""
    for flag, options in specs:
        target.add_argument(flag, **options)


# command line interface: two positionals, one device option and two
# option groups (model / training); registration order fixes the --help order
parser = argparse.ArgumentParser()

_add_arguments(parser, [
    ('features', dict(type=str, help='path to feature file in .f32 format')),
    ('output', dict(type=str, help='path to output folder')),
    ('--cuda-visible-devices', dict(type=str, help="comma separates list of cuda visible device indices, default: ''", default="")),
])

model_group = parser.add_argument_group(title="model parameters")
_add_arguments(model_group, [
    ('--latent-dim', dict(type=int, help="number of symbols produces by encoder, default: 80", default=80)),
    ('--cond-size', dict(type=int, help="first conditioning size, default: 256", default=256)),
    ('--cond-size2', dict(type=int, help="second conditioning size, default: 256", default=256)),
    ('--state-dim', dict(type=int, help="dimensionality of transfered state, default: 24", default=24)),
    ('--quant-levels', dict(type=int, help="number of quantization levels, default: 16", default=16)),
    ('--lambda-min', dict(type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)),
    ('--lambda-max', dict(type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)),
    ('--pvq-num-pulses', dict(type=int, help="number of pulses for PVQ, default: 82", default=82)),
    ('--state-dropout-rate', dict(type=float, help="state dropout rate, default: 0", default=0.0)),
    ('--softquant', dict(action="store_true", help="enables soft quantization during training")),
])

training_group = parser.add_argument_group(title="training parameters")
_add_arguments(training_group, [
    ('--batch-size', dict(type=int, help="batch size, default: 32", default=32)),
    ('--lr', dict(type=float, help='learning rate, default: 3e-4', default=3e-4)),
    ('--epochs', dict(type=int, help='number of training epochs, default: 100', default=100)),
    ('--sequence-length', dict(type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400)),
    ('--chunks-per-offset', dict(type=int, help='chunks per offset', default=4)),
    ('--lr-decay-factor', dict(type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)),
    ('--split-mode', dict(type=str, choices=['split', 'random_split', 'skewed_split'], help='splitting mode for decoder input, default: split', default='split')),
    ('--enable-first-frame-loss', dict(action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')),
    ('--initial-checkpoint', dict(type=str, help='initial checkpoint to start training from, default: None', default=None)),
    ('--train-decoder-only', dict(action='store_true', help='freeze encoder and statistical model and train decoder only')),
])

args = parser.parse_args()
|
||||
|
||||
# set visible devices (done before the first CUDA query further down)
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints: `checkpoint` collects everything that is written to disk after
# each epoch (hyper-parameters of THIS run plus the model state)
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
adam_betas = [0.8, 0.95]
adam_eps = 1e-8

checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay_factor'] = lr_decay_factor
checkpoint['split_mode'] = split_mode
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

# logging: progress bar statistics are refreshed every log_interval batches
log_interval = 10

# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model parameters
cond_size = args.cond_size
cond_size2 = args.cond_size2
latent_dim = args.latent_dim
quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
softquant = args.softquant
# not exposed
num_features = 20

# training data
feature_file = args.features

# model
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant, 'chunks_per_offset': args.chunks_per_offset}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    # Load into a separate variable: the model above was built from the
    # command line arguments, so the metadata saved with new checkpoints has
    # to stay the metadata of this run, not that of the initial checkpoint.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    initial_checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(initial_checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()

if args.train_decoder_only:
    if args.initial_checkpoint is None:
        print("warning: training decoder only without providing initial checkpoint")

    # exclude encoder and statistical model from optimization
    for p in model.core_encoder.module.parameters():
        p.requires_grad = False

    for p in model.statistical_model.parameters():
        p.requires_grad = False

# dataloader
# NOTE(review): the meaning of the literal 36 is defined by RDOVAEDataset's
# fourth positional parameter, which is not visible here -- confirm there.
checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

# optimizer: only parameters that are still trainable are handed to Adam
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)

# learning rate scheduler: lr is scaled by 1 / (1 + lr_decay_factor * step)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
|
||||
|
||||
if __name__ == '__main__':

    # push model to device
    model.to(device)

    # the rate-weight warm-up below is only applied when training from scratch
    from_scratch = args.initial_checkpoint is None

    # training loop

    # global batch counter (continues across epochs); drives the warm-up
    batch = 1
    for epoch in range(1, epochs + 1):

        print(f"training epoch {epoch}...")

        # running stats (sums over the batches of this epoch)
        running_rate_loss = 0
        running_soft_dist_loss = 0
        running_hard_dist_loss = 0
        running_hard_rate_loss = 0
        running_soft_rate_loss = 0
        running_total_loss = 0
        running_rate_metric = 0
        running_states_rate_metric = 0
        previous_total_loss = 0
        running_first_frame_loss = 0

        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, rate_lambda, q_ids) in enumerate(tepoch):

                # zero out gradients
                optimizer.zero_grad()

                # push inputs to device
                features = features.to(device)
                q_ids = q_ids.to(device)
                rate_lambda = rate_lambda.to(device)

                # repeat every lambda twice along dim 1 so it can be sliced
                # with feature frame indices in the distortion losses
                # NOTE(review): the factor 2 presumably equals model.enc_stride -- confirm
                rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)

                # run model
                model_output = model(features, q_ids)

                # collect outputs
                z = model_output['z']
                states = model_output['states']
                outputs_hard_quant = model_output['outputs_hard_quant']
                outputs_soft_quant = model_output['outputs_soft_quant']
                statistical_model = model_output['statistical_model']

                if from_scratch:
                    # warm-up: latent weight rises from ~0.5 to 1, state weight from ~0.1 to 1
                    latent_lambda = (1. - .5/(1.+batch/1000))
                    state_lambda = (1. - .9/(1.+batch/6000))
                else:
                    latent_lambda = 1.
                    state_lambda = 1.

                # rate loss (first latent_dim channels: latents, remainder: states)
                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
                states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
                states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
                rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                hard_rate_metric = torch.mean(hard_rate)
                states_rate_metric = torch.mean(states_hard_rate)

                ## distortion losses

                # hard quantized decoder input, averaged over the decoder chunks
                distortion_loss_hard_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)

                # distortion on the last 4 frames of every chunk; always
                # computed for logging, only added to the loss when enabled
                first_frame_loss = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)

                # soft quantized decoder input
                distortion_loss_soft_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_soft_quant:
                    distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)

                # total loss
                total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2

                if args.enable_first_frame_loss:
                    total_loss = .97*total_loss + 0.03 * first_frame_loss

                total_loss.backward()

                optimizer.step()

                # post-step weight maintenance
                model.clip_weights()
                model.sparsify()

                scheduler.step()

                # collect running stats
                running_hard_dist_loss += distortion_loss_hard_quant.item()
                running_soft_dist_loss += distortion_loss_soft_quant.item()
                running_rate_loss += rate_loss.item()
                running_rate_metric += hard_rate_metric.item()
                running_states_rate_metric += states_rate_metric.item()
                running_total_loss += total_loss.item()
                running_first_frame_loss += first_frame_loss.item()
                running_soft_rate_loss += soft_rate_loss.item()
                running_hard_rate_loss += hard_rate_loss.item()

                if (i + 1) % log_interval == 0:
                    # current_loss: mean total loss over the last log_interval batches
                    current_loss = (running_total_loss - previous_total_loss) / log_interval
                    tepoch.set_postfix(
                        current_loss=current_loss,
                        total_loss=running_total_loss / (i + 1),
                        dist_hq=running_hard_dist_loss / (i + 1),
                        dist_sq=running_soft_dist_loss / (i + 1),
                        rate_loss=running_rate_loss / (i + 1),
                        rate=running_rate_metric / (i + 1),
                        states_rate=running_states_rate_metric / (i + 1),
                        ffloss=running_first_frame_loss / (i + 1),
                        rateloss_hard=running_hard_rate_loss / (i + 1),
                        rateloss_soft=running_soft_rate_loss / (i + 1)
                    )
                    previous_total_loss = running_total_loss
                batch = batch+1

        # save checkpoint (one file per epoch)
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_total_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
|
||||
Reference in New Issue
Block a user