add some code
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('commonvoice_base_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('--clips-per-language', required=False, type=int, default=10)
|
||||
parser.add_argument('--seed', required=False, type=int, default=2024)
|
||||
|
||||
|
||||
def select_clips(dir, num_clips=10):
|
||||
|
||||
if num_clips % 2:
|
||||
print(f"warning: number of clips will be reduced to {num_clips - 1}")
|
||||
female = dict()
|
||||
male = dict()
|
||||
|
||||
clips = np.genfromtxt(os.path.join(dir, 'validated.tsv'), delimiter='\t', dtype=str, invalid_raise=False)
|
||||
clips_by_client = dict()
|
||||
|
||||
if len(clips.shape) < 2 or len(clips) < num_clips:
|
||||
# not enough data to proceed
|
||||
return None
|
||||
|
||||
for client in set(clips[1:,0]):
|
||||
client_clips = clips[clips[:, 0] == client]
|
||||
f, m = False, False
|
||||
if 'female_feminine' in client_clips[:, 8]:
|
||||
female[client] = client_clips[client_clips[:, 8] == 'female_feminine']
|
||||
f = True
|
||||
if 'male_masculine' in client_clips[:, 8]:
|
||||
male[client] = client_clips[client_clips[:, 8] == 'male_masculine']
|
||||
m = True
|
||||
|
||||
if f and m:
|
||||
print(f"both male and female clips under client {client}")
|
||||
|
||||
|
||||
if min(len(female), len(male)) < num_clips // 2:
|
||||
return None
|
||||
|
||||
# select num_clips // 2 random female clients
|
||||
female_client_selection = np.array(list(female.keys()), dtype=str)[np.random.choice(len(female), num_clips//2, replace=False)]
|
||||
female_clip_selection = []
|
||||
for c in female_client_selection:
|
||||
s_idx = np.random.randint(0, len(female[c]))
|
||||
female_clip_selection.append(os.path.join(dir, 'clips', female[c][s_idx, 1].item()))
|
||||
|
||||
# select num_clips // 2 random female clients
|
||||
male_client_selection = np.array(list(male.keys()), dtype=str)[np.random.choice(len(male), num_clips//2, replace=False)]
|
||||
male_clip_selection = []
|
||||
for c in male_client_selection:
|
||||
s_idx = np.random.randint(0, len(male[c]))
|
||||
male_clip_selection.append(os.path.join(dir, 'clips', male[c][s_idx, 1].item()))
|
||||
|
||||
return female_clip_selection + male_clip_selection
|
||||
|
||||
def ffmpeg_available():
|
||||
try:
|
||||
x = subprocess.run(['ffmpeg', '-h'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
return x.returncode == 0
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def convert_clips(selection, outdir):
|
||||
if not ffmpeg_available():
|
||||
raise RuntimeError("ffmpeg not available")
|
||||
|
||||
clipdir = os.path.join(outdir, 'clips')
|
||||
os.makedirs(clipdir, exist_ok=True)
|
||||
|
||||
clipdict = dict()
|
||||
|
||||
for lang, clips in selection.items():
|
||||
clipdict[lang] = []
|
||||
for clip in clips:
|
||||
clipname = os.path.splitext(os.path.split(clip)[-1])[0]
|
||||
target_name = os.path.join('clips', clipname + '.wav')
|
||||
call_args = ['ffmpeg', '-i', clip, '-ar', '16000', os.path.join(outdir, target_name)]
|
||||
print(call_args)
|
||||
r = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
if r.returncode != 0:
|
||||
raise RuntimeError(f'could not execute {call_args}')
|
||||
clipdict[lang].append(target_name)
|
||||
|
||||
return clipdict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not ffmpeg_available():
|
||||
raise RuntimeError("ffmpeg not available")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
base_dir = args.commonvoice_base_dir
|
||||
output_dir = args.output_dir
|
||||
seed = args.seed
|
||||
|
||||
np.random.seed(seed)
|
||||
|
||||
langs = os.listdir(base_dir)
|
||||
selection = dict()
|
||||
|
||||
for lang in langs:
|
||||
print(f"processing {lang}...")
|
||||
clips = select_clips(os.path.join(base_dir, lang))
|
||||
if clips is not None:
|
||||
selection[lang] = clips
|
||||
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
clips = convert_clips(selection, output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, 'clips.yml'), 'w') as f:
|
||||
yaml.dump(clips, f)
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
|
||||
INPUT="dataset/LibriSpeech"
|
||||
OUTPUT="testdata"
|
||||
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
|
||||
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )
|
||||
|
||||
|
||||
mkdir -p $OUTPUT
|
||||
|
||||
for fn in $(find $INPUT -name "*.wav")
|
||||
do
|
||||
name=$(basename ${fn%*.wav})
|
||||
sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
|
||||
for br in ${BITRATES[@]}
|
||||
do
|
||||
folder=${OUTPUT}/"${name}_${br}.se"
|
||||
echo "creating ${folder}..."
|
||||
mkdir -p $folder
|
||||
cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
|
||||
(cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
|
||||
done
|
||||
rm -f ${OUTPUT}/tmp.raw
|
||||
done
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
|
||||
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
|
||||
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
|
||||
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
|
||||
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
|
||||
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
from scipy.io import wavfile
|
||||
from pesq import pesq
|
||||
import numpy as np
|
||||
from moc import compare
|
||||
from moc2 import compare as compare2
|
||||
#from warpq import compute_WAPRQ as warpq
|
||||
from lace_loss_metric import compare as laceloss_compare
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='folder with processed items')
|
||||
parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
|
||||
|
||||
|
||||
def get_bitrates(folder):
|
||||
with open(os.path.join(folder, 'bitrates.txt')) as f:
|
||||
x = f.read()
|
||||
|
||||
bitrates = [int(y) for y in x.rstrip('\n').split()]
|
||||
|
||||
return bitrates
|
||||
|
||||
def get_itemlist(folder):
|
||||
with open(os.path.join(folder, 'items.txt')) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
items = [x.split()[0] for x in lines]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def process_item(folder, item, bitrate, metric):
|
||||
fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
|
||||
fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
|
||||
fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
|
||||
fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))
|
||||
|
||||
x_clean = x_clean.astype(np.float32) / 2**15
|
||||
x_opus = x_opus.astype(np.float32) / 2**15
|
||||
x_lace = x_lace.astype(np.float32) / 2**15
|
||||
x_nolace = x_nolace.astype(np.float32) / 2**15
|
||||
|
||||
if metric == 'pesq':
|
||||
result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
|
||||
elif metric =='moc':
|
||||
result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
|
||||
elif metric =='moc2':
|
||||
result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
|
||||
# elif metric == 'warpq':
|
||||
# result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
|
||||
elif metric == 'laceloss':
|
||||
result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
|
||||
else:
|
||||
raise ValueError(f'unknown metric {metric}')
|
||||
|
||||
return result
|
||||
|
||||
def process_bitrate(folder, items, bitrate, metric):
|
||||
results = np.zeros((len(items), 3))
|
||||
|
||||
for i, item in enumerate(items):
|
||||
results[i, :] = np.array(process_item(folder, item, bitrate, metric))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
items = get_itemlist(args.folder)
|
||||
bitrates = get_bitrates(args.folder)
|
||||
|
||||
results = dict()
|
||||
for br in bitrates:
|
||||
print(f"processing bitrate {br}...")
|
||||
results[br] = process_bitrate(args.folder, items, br, args.metric)
|
||||
|
||||
np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_window(win_name, win_length, *args, **kwargs):
|
||||
window_dict = {
|
||||
'bartlett_window' : torch.bartlett_window,
|
||||
'blackman_window' : torch.blackman_window,
|
||||
'hamming_window' : torch.hamming_window,
|
||||
'hann_window' : torch.hann_window,
|
||||
'kaiser_window' : torch.kaiser_window
|
||||
}
|
||||
|
||||
if not win_name in window_dict:
|
||||
raise ValueError()
|
||||
|
||||
return window_dict[win_name](win_length, *args, **kwargs)
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
|
||||
win = get_window(window, win_length).to(x.device)
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, win, return_complex=True)
|
||||
|
||||
|
||||
return torch.clamp(torch.abs(x_stft), min=1e-7)
|
||||
|
||||
def spectral_convergence_loss(Y_true, Y_pred):
|
||||
dims=list(range(1, len(Y_pred.shape)))
|
||||
return torch.mean(torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=dims) / (torch.norm(Y_pred, p="fro", dim=dims) + 1e-6))
|
||||
|
||||
|
||||
def log_magnitude_loss(Y_true, Y_pred):
|
||||
Y_true_log_abs = torch.log(torch.abs(Y_true) + 1e-15)
|
||||
Y_pred_log_abs = torch.log(torch.abs(Y_pred) + 1e-15)
|
||||
|
||||
return torch.mean(torch.abs(Y_true_log_abs - Y_pred_log_abs))
|
||||
|
||||
def spectral_xcorr_loss(Y_true, Y_pred):
|
||||
Y_true = Y_true.abs()
|
||||
Y_pred = Y_pred.abs()
|
||||
dims=list(range(1, len(Y_pred.shape)))
|
||||
xcorr = torch.sum(Y_true * Y_pred, dim=dims) / torch.sqrt(torch.sum(Y_true ** 2, dim=dims) * torch.sum(Y_pred ** 2, dim=dims) + 1e-9)
|
||||
|
||||
return 1 - xcorr.mean()
|
||||
|
||||
|
||||
|
||||
class MRLogMelLoss(nn.Module):
|
||||
def __init__(self,
|
||||
fft_sizes=[512, 256, 128, 64],
|
||||
overlap=0.5,
|
||||
fs=16000,
|
||||
n_mels=18
|
||||
):
|
||||
|
||||
self.fft_sizes = fft_sizes
|
||||
self.overlap = overlap
|
||||
self.fs = fs
|
||||
self.n_mels = n_mels
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.mel_specs = []
|
||||
for fft_size in fft_sizes:
|
||||
hop_size = int(round(fft_size * (1 - self.overlap)))
|
||||
|
||||
n_mels = self.n_mels
|
||||
if fft_size < 128:
|
||||
n_mels //= 2
|
||||
|
||||
self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
|
||||
|
||||
for i, mel_spec in enumerate(self.mel_specs):
|
||||
self.add_module(f'mel_spec_{i+1}', mel_spec)
|
||||
|
||||
def forward(self, y_true, y_pred):
|
||||
|
||||
loss = torch.zeros(1, device=y_true.device)
|
||||
|
||||
for mel_spec in self.mel_specs:
|
||||
Y_true = mel_spec(y_true)
|
||||
Y_pred = mel_spec(y_pred)
|
||||
loss = loss + log_magnitude_loss(Y_true, Y_pred)
|
||||
|
||||
loss = loss / len(self.mel_specs)
|
||||
|
||||
return loss
|
||||
|
||||
def create_weight_matrix(num_bins, bins_per_band=10):
|
||||
m = torch.zeros((num_bins, num_bins), dtype=torch.float32)
|
||||
|
||||
r0 = bins_per_band // 2
|
||||
r1 = bins_per_band - r0
|
||||
|
||||
for i in range(num_bins):
|
||||
i0 = max(i - r0, 0)
|
||||
j0 = min(i + r1, num_bins)
|
||||
|
||||
m[i, i0: j0] += 1
|
||||
|
||||
if i < r0:
|
||||
m[i, :r0 - i] += 1
|
||||
|
||||
if i > num_bins - r1:
|
||||
m[i, num_bins - r1 - i:] += 1
|
||||
|
||||
return m / bins_per_band
|
||||
|
||||
def weighted_spectral_convergence(Y_true, Y_pred, w):
|
||||
|
||||
# calculate sfm based weights
|
||||
logY = torch.log(torch.abs(Y_true) + 1e-9)
|
||||
Y = torch.abs(Y_true)
|
||||
|
||||
avg_logY = torch.matmul(logY.transpose(1, 2), w)
|
||||
avg_Y = torch.matmul(Y.transpose(1, 2), w)
|
||||
|
||||
sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
|
||||
|
||||
weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
|
||||
|
||||
loss = torch.mean(
|
||||
torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
|
||||
/ (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def gen_filterbank(N, Fs=16000):
|
||||
in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
|
||||
out_freq = (np.arange(N, dtype='float32')/N*Fs/2)[:,None]
|
||||
#ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
|
||||
ERB_N = 24.7 + .108*in_freq
|
||||
delta = np.abs(in_freq-out_freq)/ERB_N
|
||||
center = (delta<.5).astype('float32')
|
||||
R = -12*center*delta**2 + (1-center)*(3-12*delta)
|
||||
RE = 10.**(R/10.)
|
||||
norm = np.sum(RE, axis=1)
|
||||
RE = RE/norm[:, np.newaxis]
|
||||
return torch.from_numpy(RE)
|
||||
|
||||
def smooth_log_mag(Y_true, Y_pred, filterbank):
|
||||
Y_true_smooth = torch.matmul(filterbank, torch.abs(Y_true))
|
||||
Y_pred_smooth = torch.matmul(filterbank, torch.abs(Y_pred))
|
||||
|
||||
loss = torch.abs(
|
||||
torch.log(Y_true_smooth + 1e-9) - torch.log(Y_pred_smooth + 1e-9)
|
||||
)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
class MRSTFTLoss(nn.Module):
|
||||
def __init__(self,
|
||||
fft_sizes=[2048, 1024, 512, 256, 128, 64],
|
||||
overlap=0.5,
|
||||
window='hann_window',
|
||||
fs=16000,
|
||||
log_mag_weight=0,
|
||||
sc_weight=0,
|
||||
wsc_weight=0,
|
||||
smooth_log_mag_weight=2,
|
||||
sxcorr_weight=1):
|
||||
super().__init__()
|
||||
|
||||
self.fft_sizes = fft_sizes
|
||||
self.overlap = overlap
|
||||
self.window = window
|
||||
self.log_mag_weight = log_mag_weight
|
||||
self.sc_weight = sc_weight
|
||||
self.wsc_weight = wsc_weight
|
||||
self.smooth_log_mag_weight = smooth_log_mag_weight
|
||||
self.sxcorr_weight = sxcorr_weight
|
||||
self.fs = fs
|
||||
|
||||
# weights for SFM weighted spectral convergence loss
|
||||
self.wsc_weights = torch.nn.ParameterDict()
|
||||
for fft_size in fft_sizes:
|
||||
width = min(11, int(1000 * fft_size / self.fs + .5))
|
||||
width += width % 2
|
||||
self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
|
||||
create_weight_matrix(fft_size // 2 + 1, width),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
# filterbanks for smooth log magnitude loss
|
||||
self.filterbanks = torch.nn.ParameterDict()
|
||||
for fft_size in fft_sizes:
|
||||
self.filterbanks[str(fft_size)] = torch.nn.Parameter(
|
||||
gen_filterbank(fft_size//2),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def __call__(self, y_true, y_pred):
|
||||
|
||||
|
||||
lm_loss = torch.zeros(1, device=y_true.device)
|
||||
sc_loss = torch.zeros(1, device=y_true.device)
|
||||
wsc_loss = torch.zeros(1, device=y_true.device)
|
||||
slm_loss = torch.zeros(1, device=y_true.device)
|
||||
sxcorr_loss = torch.zeros(1, device=y_true.device)
|
||||
|
||||
for fft_size in self.fft_sizes:
|
||||
hop_size = int(round(fft_size * (1 - self.overlap)))
|
||||
win_size = fft_size
|
||||
|
||||
Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
|
||||
Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
|
||||
|
||||
if self.log_mag_weight > 0:
|
||||
lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
|
||||
|
||||
if self.sc_weight > 0:
|
||||
sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
|
||||
|
||||
if self.wsc_weight > 0:
|
||||
wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
|
||||
|
||||
if self.smooth_log_mag_weight > 0:
|
||||
slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
|
||||
|
||||
if self.sxcorr_weight > 0:
|
||||
sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
|
||||
|
||||
|
||||
total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
|
||||
+ self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
|
||||
+ self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def td_l2_norm(y_true, y_pred):
|
||||
dims = list(range(1, len(y_true.shape)))
|
||||
|
||||
loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class LaceLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)
|
||||
|
||||
|
||||
def forward(self, x, y):
|
||||
specloss = self.stftloss(x, y)
|
||||
phaseloss = td_l2_norm(x, y)
|
||||
total_loss = (specloss + 10 * phaseloss) / 13
|
||||
|
||||
return total_loss
|
||||
|
||||
def compare(self, x_ref, x_deg):
|
||||
# trim items to same size
|
||||
n = min(len(x_ref), len(x_deg))
|
||||
x_ref = x_ref[:n].copy()
|
||||
x_deg = x_deg[:n].copy()
|
||||
|
||||
# pre-emphasis
|
||||
x_ref[1:] -= 0.85 * x_ref[:-1]
|
||||
x_deg[1:] -= 0.85 * x_deg[:-1]
|
||||
|
||||
device = next(iter(self.parameters())).device
|
||||
|
||||
x = torch.from_numpy(x_ref).to(device)
|
||||
y = torch.from_numpy(x_deg).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
dist = 10 * self.forward(x, y)
|
||||
|
||||
return dist.cpu().numpy().item()
|
||||
|
||||
|
||||
lace_loss = LaceLoss()
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
lace_loss.to(device)
|
||||
|
||||
def compare(x, y):
|
||||
|
||||
return lace_loss.compare(x, y)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
|
||||
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
|
||||
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
|
||||
data = dict()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
|
||||
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
|
||||
data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
|
||||
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
|
||||
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
|
||||
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
|
||||
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
|
||||
|
||||
return data
|
||||
|
||||
def plot_data(filename, data, title=None):
|
||||
compare_dict = dict()
|
||||
for br in data.keys():
|
||||
compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
|
||||
compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
|
||||
compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]
|
||||
|
||||
plt.rcParams.update({
|
||||
"text.usetex": True,
|
||||
"font.family": "Helvetica",
|
||||
"font.size": 32
|
||||
})
|
||||
|
||||
black = '#000000'
|
||||
red = '#ff5745'
|
||||
blue = '#007dbc'
|
||||
colors = [black, red, blue]
|
||||
legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
|
||||
Patch(facecolor=colors[1], label='LACE'),
|
||||
Patch(facecolor=colors[2], label='NoLACE')]
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
fig.set_size_inches(40, 20)
|
||||
bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
|
||||
|
||||
for i, patch in enumerate(bplot['boxes']):
|
||||
patch.set_facecolor(colors[i%3])
|
||||
|
||||
ax.set_xticklabels(compare_dict.keys(), rotation=290)
|
||||
|
||||
if title is not None:
|
||||
ax.set_title(title)
|
||||
|
||||
ax.legend(handles=legend_elements)
|
||||
|
||||
fig.savefig(filename, bbox_inches='tight')
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
data = load_data(args.folder)
|
||||
|
||||
|
||||
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
|
||||
folder = args.folder if args.output is None else args.output
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
for metric in metrics:
|
||||
print(f"Plotting data for {metric} metric...")
|
||||
plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
|
||||
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
|
||||
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
|
||||
data = dict()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
|
||||
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
|
||||
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
|
||||
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
|
||||
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
|
||||
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
|
||||
|
||||
return data
|
||||
|
||||
def plot_data(filename, data, title=None):
|
||||
compare_dict = dict()
|
||||
for br in data.keys():
|
||||
compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
|
||||
compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
|
||||
compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]
|
||||
|
||||
plt.rcParams.update({
|
||||
"text.usetex": True,
|
||||
"font.family": "Helvetica",
|
||||
"font.size": 32
|
||||
})
|
||||
colors = ['pink', 'lightblue', 'lightgreen']
|
||||
legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
|
||||
Patch(facecolor=colors[1], label='MOC loss only'),
|
||||
Patch(facecolor=colors[2], label='MOC + TD loss')]
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
fig.set_size_inches(40, 20)
|
||||
bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
|
||||
|
||||
for i, patch in enumerate(bplot['boxes']):
|
||||
patch.set_facecolor(colors[i%3])
|
||||
|
||||
ax.set_xticklabels(compare_dict.keys(), rotation=290)
|
||||
|
||||
if title is not None:
|
||||
ax.set_title(title)
|
||||
|
||||
ax.legend(handles=legend_elements)
|
||||
|
||||
fig.savefig(filename, bbox_inches='tight')
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
data = load_data(args.folder)
|
||||
|
||||
|
||||
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
|
||||
folder = args.folder if args.output is None else args.output
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
for metric in metrics:
|
||||
print(f"Plotting data for {metric} metric...")
|
||||
plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
|
||||
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
|
||||
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
|
||||
data = dict()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
|
||||
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc2.npy')):
|
||||
data['moc2'] = np.load(os.path.join(folder, 'results_moc2.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
|
||||
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
|
||||
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
|
||||
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
|
||||
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
|
||||
|
||||
return data
|
||||
|
||||
def make_table(filename, data, title=None):
|
||||
|
||||
# mean values
|
||||
tbl = PrettyTable()
|
||||
tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
|
||||
for br in data.keys():
|
||||
opus = data[br][:, 0]
|
||||
lace = data[br][:, 1]
|
||||
nolace = data[br][:, 2]
|
||||
tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
|
||||
|
||||
with open(filename + ".txt", "w") as f:
|
||||
f.write(str(tbl))
|
||||
|
||||
with open(filename + ".html", "w") as f:
|
||||
f.write(tbl.get_html_string())
|
||||
|
||||
with open(filename + ".csv", "w") as f:
|
||||
f.write(tbl.get_csv_string())
|
||||
|
||||
print(tbl)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
|
||||
|
||||
# mean values
|
||||
tbl = PrettyTable()
|
||||
tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
|
||||
for br in data.keys():
|
||||
opus = data[br][:, 0]
|
||||
lace = data[br][:, 1] - opus
|
||||
nolace = data[br][:, 2] - opus
|
||||
tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
|
||||
|
||||
with open(filename + ".txt", "w") as f:
|
||||
f.write(str(tbl))
|
||||
|
||||
with open(filename + ".html", "w") as f:
|
||||
f.write(tbl.get_html_string())
|
||||
|
||||
with open(filename + ".csv", "w") as f:
|
||||
f.write(tbl.get_csv_string())
|
||||
|
||||
print(tbl)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
data = load_data(args.folder)
|
||||
|
||||
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
|
||||
folder = args.folder if args.output is None else args.output
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
for metric in metrics:
|
||||
print(f"Plotting data for {metric} metric...")
|
||||
make_table(os.path.join(folder, f"table_{metric}"), data[metric])
|
||||
make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from prettytable import PrettyTable
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
|
||||
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
|
||||
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
|
||||
|
||||
def load_data(folder):
|
||||
data = dict()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_moc.npy')):
|
||||
data['moc'] = np.load(os.path.join(folder, 'results_moc.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_pesq.npy')):
|
||||
data['pesq'] = np.load(os.path.join(folder, 'results_pesq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_warpq.npy')):
|
||||
data['warpq'] = np.load(os.path.join(folder, 'results_warpq.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_nomad.npy')):
|
||||
data['nomad'] = np.load(os.path.join(folder, 'results_nomad.npy'), allow_pickle=True).item()
|
||||
|
||||
if os.path.isfile(os.path.join(folder, 'results_laceloss.npy')):
|
||||
data['laceloss'] = np.load(os.path.join(folder, 'results_laceloss.npy'), allow_pickle=True).item()
|
||||
|
||||
return data
|
||||
|
||||
def make_table(filename, data, title=None):
|
||||
|
||||
# mean values
|
||||
tbl = PrettyTable()
|
||||
tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
|
||||
for br in data.keys():
|
||||
opus = data[br][:, 0]
|
||||
lace = data[br][:, 1]
|
||||
nolace = data[br][:, 2]
|
||||
tbl.add_row([br, f"{float(opus.mean()):.3f} ({float(opus.std()):.2f})", f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
|
||||
|
||||
with open(filename + ".txt", "w") as f:
|
||||
f.write(str(tbl))
|
||||
|
||||
with open(filename + ".html", "w") as f:
|
||||
f.write(tbl.get_html_string())
|
||||
|
||||
with open(filename + ".csv", "w") as f:
|
||||
f.write(tbl.get_csv_string())
|
||||
|
||||
print(tbl)
|
||||
|
||||
|
||||
def make_diff_table(filename, data, title=None):
|
||||
|
||||
# mean values
|
||||
tbl = PrettyTable()
|
||||
tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
|
||||
for br in data.keys():
|
||||
opus = data[br][:, 0]
|
||||
lace = data[br][:, 1] - opus
|
||||
nolace = data[br][:, 2] - opus
|
||||
tbl.add_row([br, f"{float(lace.mean()):.3f} ({float(lace.std()):.2f})", f"{float(nolace.mean()):.3f} ({float(nolace.std()):.2f})"])
|
||||
|
||||
with open(filename + ".txt", "w") as f:
|
||||
f.write(str(tbl))
|
||||
|
||||
with open(filename + ".html", "w") as f:
|
||||
f.write(tbl.get_html_string())
|
||||
|
||||
with open(filename + ".csv", "w") as f:
|
||||
f.write(tbl.get_csv_string())
|
||||
|
||||
print(tbl)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
data = load_data(args.folder)
|
||||
|
||||
metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
|
||||
folder = args.folder if args.output is None else args.output
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
for metric in metrics:
|
||||
print(f"Plotting data for {metric} metric...")
|
||||
make_table(os.path.join(folder, f"table_{metric}"), data[metric])
|
||||
make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
|
||||
|
||||
frame_length = (fs + 49) // 50
|
||||
x = x[: frame_length * (len(x) // frame_length)]
|
||||
|
||||
frames = x.reshape(-1, frame_length)
|
||||
frame_energy = np.sum(frames ** 2, axis=1)
|
||||
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
|
||||
|
||||
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
|
||||
vactive = np.ones_like(frames)
|
||||
vactive[frame_energy_smooth < max_threshold, :] = 0
|
||||
vactive = vactive.reshape(-1)
|
||||
|
||||
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
|
||||
filter = filter / filter.sum()
|
||||
|
||||
mask = np.convolve(vactive, filter, mode='same')
|
||||
|
||||
return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
|
||||
num_samples = frame_size + (num_frames - 1) * hop_size
|
||||
if len(mask) < num_samples:
|
||||
mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
|
||||
else:
|
||||
mask = mask[:num_samples]
|
||||
|
||||
new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
|
||||
|
||||
return new_mask
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
|
||||
num_spectra = (len(x) - window_size - hop_size) // hop_size
|
||||
window = scipy.signal.get_window(window, window_size)
|
||||
N = window_size // 2
|
||||
|
||||
frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
|
||||
psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
|
||||
|
||||
return psd
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
|
||||
|
||||
up_mask = np.zeros((num_bands, num_bands))
|
||||
down_mask = np.zeros((num_bands, num_bands))
|
||||
|
||||
for i in range(num_bands):
|
||||
up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
|
||||
down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
|
||||
|
||||
return down_mask @ up_mask
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
|
||||
num_bands = len(band_limits) - 1
|
||||
if num_bins is None:
|
||||
num_bins = band_limits[-1]
|
||||
|
||||
fb = np.zeros((num_bands, num_bins))
|
||||
for i in range(num_bands):
|
||||
fb[i, band_limits[i]:band_limits[i+1]] = 1
|
||||
|
||||
return fb
|
||||
|
||||
|
||||
def compare(x, y, apply_vad=False):
|
||||
""" Modified version of opus_compare for 16 kHz mono signals
|
||||
|
||||
Args:
|
||||
x (np.ndarray): reference input signal scaled to [-1, 1]
|
||||
y (np.ndarray): test signal scaled to [-1, 1]
|
||||
|
||||
Returns:
|
||||
float: perceptually weighted error
|
||||
"""
|
||||
# filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
|
||||
band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
|
||||
num_bands = len(band_limits) - 1
|
||||
fb = rect_fb(band_limits, num_bins=81)
|
||||
|
||||
# trim samples to same size
|
||||
num_samples = min(len(x), len(y))
|
||||
x = x[:num_samples] * 2**15
|
||||
y = y[:num_samples] * 2**15
|
||||
|
||||
psd_x = power_spectrum(x) + 100000
|
||||
psd_y = power_spectrum(y) + 100000
|
||||
|
||||
num_frames = psd_x.shape[0]
|
||||
|
||||
# average band energies
|
||||
be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
|
||||
|
||||
# frequecy masking
|
||||
f_mask = frequency_mask(num_bands, 0.1, 0.03)
|
||||
mask_x = be_x @ f_mask.T
|
||||
|
||||
# temporal masking
|
||||
for i in range(1, num_frames):
|
||||
mask_x[i, :] += 0.5 * mask_x[i-1, :]
|
||||
|
||||
# apply mask
|
||||
masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
|
||||
masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
|
||||
|
||||
# 2-frame average
|
||||
masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
|
||||
masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
|
||||
|
||||
# distortion metric
|
||||
re = masked_psd_y / masked_psd_x
|
||||
im = np.log(re) ** 2
|
||||
Eb = ((im @ fb.T) / np.sum(fb, axis=1))
|
||||
Ef = np.mean(Eb , axis=1)
|
||||
|
||||
if apply_vad:
|
||||
_, mask = compute_vad_mask(x, 16000)
|
||||
mask = convert_mask(mask, Ef.shape[0])
|
||||
else:
|
||||
mask = np.ones_like(Ef)
|
||||
|
||||
err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
|
||||
|
||||
return float(err)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from scipy.io import wavfile
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('ref', type=str, help='reference wav file')
|
||||
parser.add_argument('deg', type=str, help='degraded wav file')
|
||||
parser.add_argument('--apply-vad', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
fs1, x = wavfile.read(args.ref)
|
||||
fs2, y = wavfile.read(args.deg)
|
||||
|
||||
if max(fs1, fs2) != 16000:
|
||||
raise ValueError('error: encountered sampling frequency diffrent from 16kHz')
|
||||
|
||||
x = x.astype(np.float32) / 2**15
|
||||
y = y.astype(np.float32) / 2**15
|
||||
|
||||
err = compare(x, y, apply_vad=args.apply_vad)
|
||||
|
||||
print(f"MOC: {err}")
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
def compute_vad_mask(x, fs, stop_db=-70):
|
||||
|
||||
frame_length = (fs + 49) // 50
|
||||
x = x[: frame_length * (len(x) // frame_length)]
|
||||
|
||||
frames = x.reshape(-1, frame_length)
|
||||
frame_energy = np.sum(frames ** 2, axis=1)
|
||||
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
|
||||
|
||||
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
|
||||
vactive = np.ones_like(frames)
|
||||
vactive[frame_energy_smooth < max_threshold, :] = 0
|
||||
vactive = vactive.reshape(-1)
|
||||
|
||||
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
|
||||
filter = filter / filter.sum()
|
||||
|
||||
mask = np.convolve(vactive, filter, mode='same')
|
||||
|
||||
return x, mask
|
||||
|
||||
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
|
||||
num_samples = frame_size + (num_frames - 1) * hop_size
|
||||
if len(mask) < num_samples:
|
||||
mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
|
||||
else:
|
||||
mask = mask[:num_samples]
|
||||
|
||||
new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
|
||||
|
||||
return new_mask
|
||||
|
||||
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
|
||||
num_spectra = (len(x) - window_size - hop_size) // hop_size
|
||||
window = scipy.signal.get_window(window, window_size)
|
||||
N = window_size // 2
|
||||
|
||||
frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
|
||||
psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
|
||||
|
||||
return psd
|
||||
|
||||
|
||||
def frequency_mask(num_bands, up_factor, down_factor):
|
||||
|
||||
up_mask = np.zeros((num_bands, num_bands))
|
||||
down_mask = np.zeros((num_bands, num_bands))
|
||||
|
||||
for i in range(num_bands):
|
||||
up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
|
||||
down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
|
||||
|
||||
return down_mask @ up_mask
|
||||
|
||||
|
||||
def rect_fb(band_limits, num_bins=None):
|
||||
num_bands = len(band_limits) - 1
|
||||
if num_bins is None:
|
||||
num_bins = band_limits[-1]
|
||||
|
||||
fb = np.zeros((num_bands, num_bins))
|
||||
for i in range(num_bands):
|
||||
fb[i, band_limits[i]:band_limits[i+1]] = 1
|
||||
|
||||
return fb
|
||||
|
||||
|
||||
def _compare(x, y, apply_vad=False, factor=1):
|
||||
""" Modified version of opus_compare for 16 kHz mono signals
|
||||
|
||||
Args:
|
||||
x (np.ndarray): reference input signal scaled to [-1, 1]
|
||||
y (np.ndarray): test signal scaled to [-1, 1]
|
||||
|
||||
Returns:
|
||||
float: perceptually weighted error
|
||||
"""
|
||||
# filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
|
||||
band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
|
||||
window_size = factor * 160
|
||||
hop_size = factor * 40
|
||||
num_bins = window_size // 2 + 1
|
||||
num_bands = len(band_limits) - 1
|
||||
fb = rect_fb(band_limits, num_bins=num_bins)
|
||||
|
||||
# trim samples to same size
|
||||
num_samples = min(len(x), len(y))
|
||||
x = x[:num_samples].copy() * 2**15
|
||||
y = y[:num_samples].copy() * 2**15
|
||||
|
||||
psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
|
||||
psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000
|
||||
|
||||
num_frames = psd_x.shape[0]
|
||||
|
||||
# average band energies
|
||||
be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)
|
||||
|
||||
# frequecy masking
|
||||
f_mask = frequency_mask(num_bands, 0.1, 0.03)
|
||||
mask_x = be_x @ f_mask.T
|
||||
|
||||
# temporal masking
|
||||
for i in range(1, num_frames):
|
||||
mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]
|
||||
|
||||
# apply mask
|
||||
masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
|
||||
masked_psd_y = psd_y + 0.1 * (mask_x @ fb)
|
||||
|
||||
# 2-frame average
|
||||
masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
|
||||
masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]
|
||||
|
||||
# distortion metric
|
||||
re = masked_psd_y / masked_psd_x
|
||||
#im = re - np.log(re) - 1
|
||||
im = np.log(re) ** 2
|
||||
Eb = ((im @ fb.T) / np.sum(fb, axis=1))
|
||||
Ef = np.mean(Eb ** 1, axis=1)
|
||||
|
||||
if apply_vad:
|
||||
_, mask = compute_vad_mask(x, 16000)
|
||||
mask = convert_mask(mask, Ef.shape[0])
|
||||
else:
|
||||
mask = np.ones_like(Ef)
|
||||
|
||||
err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
|
||||
|
||||
return float(err)
|
||||
|
||||
def compare(x, y, apply_vad=False):
|
||||
err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2)
|
||||
return err
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from scipy.io import wavfile
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('ref', type=str, help='reference wav file')
|
||||
parser.add_argument('deg', type=str, help='degraded wav file')
|
||||
parser.add_argument('--apply-vad', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
fs1, x = wavfile.read(args.ref)
|
||||
fs2, y = wavfile.read(args.deg)
|
||||
|
||||
if max(fs1, fs2) != 16000:
|
||||
raise ValueError('error: encountered sampling frequency diffrent from 16kHz')
|
||||
|
||||
x = x.astype(np.float32) / 2**15
|
||||
y = y.astype(np.float32) / 2**15
|
||||
|
||||
err = compare(x, y, apply_vad=args.apply_vad)
|
||||
|
||||
print(f"MOC: {err}")
|
||||
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ ! -f "$PYTHON" ]
|
||||
then
|
||||
echo "PYTHON variable does not link to a file. Please point it to your python executable."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$TESTMODEL" ]
|
||||
then
|
||||
echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$OPUSDEMO" ]
|
||||
then
|
||||
echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$LACE" ]
|
||||
then
|
||||
echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$NOLACE" ]
|
||||
then
|
||||
echo "LACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
case $# in
|
||||
2) INPUT=$1; OUTPUT=$2;;
|
||||
*) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
|
||||
esac
|
||||
|
||||
if [ -d $OUTPUT ]
|
||||
then
|
||||
echo "output folder $OUTPUT exists, aborting..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p $OUTPUT
|
||||
|
||||
if [ "$BITRATES" == "" ]
|
||||
then
|
||||
BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
|
||||
echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
|
||||
fi
|
||||
|
||||
|
||||
echo "LACE=${LACE}" > ${OUTPUT}/info.txt
|
||||
echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt
|
||||
|
||||
ITEMFILE=${OUTPUT}/items.txt
|
||||
BITRATEFILE=${OUTPUT}/bitrates.txt
|
||||
|
||||
FPROCESSING=${OUTPUT}/processing
|
||||
FCLEAN=${OUTPUT}/clean
|
||||
FOPUS=${OUTPUT}/opus
|
||||
FLACE=${OUTPUT}/lace
|
||||
FNOLACE=${OUTPUT}/nolace
|
||||
|
||||
mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE
|
||||
|
||||
echo "${BITRATES[@]}" > $BITRATEFILE
|
||||
|
||||
for fn in $(find $INPUT -type f -name "*.wav")
|
||||
do
|
||||
UUID=$(uuid)
|
||||
echo "$UUID $fn" >> $ITEMFILE
|
||||
PIDS=( )
|
||||
for br in ${BITRATES[@]}
|
||||
do
|
||||
# run opus
|
||||
pfolder=${FPROCESSING}/${UUID}_${br}
|
||||
mkdir -p $pfolder
|
||||
sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16
|
||||
(cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
|
||||
|
||||
# copy clean and opus
|
||||
sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav
|
||||
sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav
|
||||
|
||||
# run LACE
|
||||
$PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav &
|
||||
PIDS+=( "$!" )
|
||||
|
||||
# run NoLACE
|
||||
$PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav &
|
||||
PIDS+=( "$!" )
|
||||
done
|
||||
for pid in ${PIDS[@]}
|
||||
do
|
||||
wait $pid
|
||||
done
|
||||
done
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
from scipy.spatial.distance import cdist
|
||||
from scipy.io import wavfile
|
||||
import numpy as np
|
||||
|
||||
from nomad_audio.nomad import Nomad
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('folder', type=str, help='folder with processed items')
|
||||
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
|
||||
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
|
||||
|
||||
|
||||
def get_bitrates(folder):
|
||||
with open(os.path.join(folder, 'bitrates.txt')) as f:
|
||||
x = f.read()
|
||||
|
||||
bitrates = [int(y) for y in x.rstrip('\n').split()]
|
||||
|
||||
return bitrates
|
||||
|
||||
def get_itemlist(folder):
|
||||
with open(os.path.join(folder, 'items.txt')) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
items = [x.split()[0] for x in lines]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
|
||||
model = Nomad(device=device)
|
||||
if not full_reference:
|
||||
results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
|
||||
return results, None
|
||||
else:
|
||||
if ref_embeddings is None:
|
||||
print(f"Computing reference embeddings from {ref_folder}")
|
||||
ref_data = pd.DataFrame(sorted(os.listdir(ref_folder)))
|
||||
ref_data.columns = ['filename']
|
||||
ref_data['filename'] = [os.path.join(ref_folder, x) for x in ref_data['filename']]
|
||||
ref_embeddings = model.get_embeddings_csv(model.model, ref_data).set_index('filename')
|
||||
|
||||
print(f"Computing degraded embeddings from {deg_folder}")
|
||||
deg_data = pd.DataFrame(sorted(os.listdir(deg_folder)))
|
||||
deg_data.columns = ['filename']
|
||||
deg_data['filename'] = [os.path.join(deg_folder, x) for x in deg_data['filename']]
|
||||
deg_embeddings = model.get_embeddings_csv(model.model, deg_data).set_index('filename')
|
||||
|
||||
dist = np.diag(cdist(ref_embeddings, deg_embeddings)) # wasteful
|
||||
test_files = [x.split('/')[-1].split('.')[0] for x in deg_embeddings.index]
|
||||
|
||||
results = dict(zip(test_files, dist))
|
||||
|
||||
return results, ref_embeddings
|
||||
|
||||
|
||||
|
||||
|
||||
def nomad_process_all(folder, full_reference=False, device=None):
|
||||
bitrates = get_bitrates(folder)
|
||||
items = get_itemlist(folder)
|
||||
with tempfile.TemporaryDirectory() as dir:
|
||||
cleandir = os.path.join(dir, 'clean')
|
||||
opusdir = os.path.join(dir, 'opus')
|
||||
lacedir = os.path.join(dir, 'lace')
|
||||
nolacedir = os.path.join(dir, 'nolace')
|
||||
|
||||
# prepare files
|
||||
for d in [cleandir, opusdir, lacedir, nolacedir]: os.makedirs(d)
|
||||
for br in bitrates:
|
||||
for item in items:
|
||||
for cond in ['clean', 'opus', 'lace', 'nolace']:
|
||||
shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(dir, cond, f"{item}_{br}.wav"))
|
||||
|
||||
nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None)
|
||||
nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
|
||||
nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
|
||||
|
||||
results = dict()
|
||||
for br in bitrates:
|
||||
results[br] = np.zeros((len(items), 3))
|
||||
for i, item in enumerate(items):
|
||||
key = f"{item}_{br}"
|
||||
results[br][i, 0] = nomad_opus[key]
|
||||
results[br][i, 1] = nomad_lace[key]
|
||||
results[br][i, 2] = nomad_nolace[key]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
items = get_itemlist(args.folder)
|
||||
bitrates = get_bitrates(args.folder)
|
||||
|
||||
results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)
|
||||
|
||||
np.save(os.path.join(args.folder, f'results_nomad.npy'), results)
|
||||
|
||||
print("Done.")
|
||||
@@ -0,0 +1,196 @@
|
||||
import os
|
||||
import argparse
|
||||
import yaml
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
|
||||
from moc2 import compare as moc
|
||||
|
||||
DEBUG=False
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('inputdir', type=str, help='Input folder with test items')
|
||||
parser.add_argument('outputdir', type=str, help='Output folder')
|
||||
parser.add_argument('bitrate', type=int, help='bitrate to test')
|
||||
parser.add_argument('--reference_opus_demo', type=str, default='./opus_demo', help='reference opus_demo binary for generating bitstreams and reference output')
|
||||
parser.add_argument('--encoder_options', type=str, default="", help='encoder options (e.g. -complexity 5)')
|
||||
parser.add_argument('--test_opus_demo', type=str, default='./opus_demo', help='opus_demo binary under test')
|
||||
parser.add_argument('--test_opus_demo_options', type=str, default='-dec_complexity 7', help='options for test opus_demo (e.g. "-dec_complexity 7")')
|
||||
parser.add_argument('--verbose', type=int, default=0, help='verbosity level: 0 for quiet (default), 1 for reporting individual test results, 2 for reporting per-item scores in failed tests')
|
||||
|
||||
def run_opus_encoder(opus_demo_path, input_pcm_path, bitstream_path, application, fs, num_channels, bitrate, options=[], verbose=False):
|
||||
|
||||
call_args = [
|
||||
opus_demo_path,
|
||||
"-e",
|
||||
application,
|
||||
str(fs),
|
||||
str(num_channels),
|
||||
str(bitrate),
|
||||
"-bandwidth",
|
||||
"WB"
|
||||
]
|
||||
|
||||
call_args += options
|
||||
|
||||
call_args += [
|
||||
input_pcm_path,
|
||||
bitstream_path
|
||||
]
|
||||
|
||||
try:
|
||||
if verbose:
|
||||
print(f"running {call_args}...")
|
||||
subprocess.run(call_args)
|
||||
else:
|
||||
subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
except:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def run_opus_decoder(opus_demo_path, bitstream_path, output_pcm_path, fs, num_channels, options=[], verbose=False):
|
||||
|
||||
call_args = [
|
||||
opus_demo_path,
|
||||
"-d",
|
||||
str(fs),
|
||||
str(num_channels)
|
||||
]
|
||||
|
||||
call_args += options
|
||||
|
||||
call_args += [
|
||||
bitstream_path,
|
||||
output_pcm_path
|
||||
]
|
||||
|
||||
try:
|
||||
if verbose:
|
||||
print(f"running {call_args}...")
|
||||
subprocess.run(call_args)
|
||||
else:
|
||||
subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
except:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
def compute_moc_score(reference_pcm, test_pcm, delay=91):
|
||||
x_ref = np.fromfile(reference_pcm, dtype=np.int16).astype(np.float32) / (2 ** 15)
|
||||
x_cut = np.fromfile(test_pcm, dtype=np.int16).astype(np.float32) / (2 ** 15)
|
||||
|
||||
moc_score = moc(x_ref, x_cut[delay:])
|
||||
|
||||
return moc_score
|
||||
|
||||
def sox(*call_args):
|
||||
try:
|
||||
call_args = ["sox"] + list(call_args)
|
||||
subprocess.run(call_args)
|
||||
return 0
|
||||
except:
|
||||
return 1
|
||||
|
||||
def process_clip_factory(ref_opus_demo, test_opus_demo, enc_options, test_options):
|
||||
def process_clip(clip_path, processdir, bitrate):
|
||||
# derive paths
|
||||
clipname = os.path.splitext(os.path.split(clip_path)[1])[0]
|
||||
pcm_path = os.path.join(processdir, clipname + ".raw")
|
||||
bitstream_path = os.path.join(processdir, clipname + ".bin")
|
||||
ref_path = os.path.join(processdir, clipname + "_ref.raw")
|
||||
test_path = os.path.join(processdir, clipname + "_test.raw")
|
||||
|
||||
# run sox
|
||||
sox(clip_path, pcm_path)
|
||||
|
||||
# run encoder
|
||||
run_opus_encoder(ref_opus_demo, pcm_path, bitstream_path, "voip", 16000, 1, bitrate, enc_options)
|
||||
|
||||
# run decoder
|
||||
run_opus_decoder(ref_opus_demo, bitstream_path, ref_path, 16000, 1)
|
||||
run_opus_decoder(test_opus_demo, bitstream_path, test_path, 16000, 1, options=test_options)
|
||||
|
||||
d_ref = compute_moc_score(pcm_path, ref_path)
|
||||
d_test = compute_moc_score(pcm_path, test_path)
|
||||
|
||||
return d_ref, d_test
|
||||
|
||||
|
||||
return process_clip
|
||||
|
||||
def main(inputdir, outputdir, bitrate, reference_opus_demo, test_opus_demo, enc_option_string, test_option_string, verbose):
|
||||
|
||||
# load clips list
|
||||
with open(os.path.join(inputdir, 'clips.yml'), "r") as f:
|
||||
clips = yaml.safe_load(f)
|
||||
|
||||
# parse test options
|
||||
enc_options = enc_option_string.split()
|
||||
test_options = test_option_string.split()
|
||||
|
||||
process_clip = process_clip_factory(reference_opus_demo, test_opus_demo, enc_options, test_options)
|
||||
|
||||
os.makedirs(outputdir, exist_ok=True)
|
||||
processdir = os.path.join(outputdir, 'process')
|
||||
os.makedirs(processdir, exist_ok=True)
|
||||
|
||||
num_passed = 0
|
||||
results = dict()
|
||||
min_rel_diff = 1000
|
||||
min_mean = 1000
|
||||
worst_clip = None
|
||||
worst_lang = None
|
||||
for lang, lang_clips in clips.items():
|
||||
if verbose > 0: print(f"processing language {lang}...")
|
||||
results[lang] = np.zeros((len(lang_clips), 2))
|
||||
for i, clip in enumerate(lang_clips):
|
||||
clip_path = os.path.join(inputdir, clip)
|
||||
d_ref, d_test = process_clip(clip_path, processdir, bitrate)
|
||||
results[lang][i, 0] = d_ref
|
||||
results[lang][i, 1] = d_test
|
||||
|
||||
alpha = 0.5
|
||||
rel_diff = ((results[lang][:, 0] ** alpha - results[lang][:, 1] ** alpha) /(results[lang][:, 0] ** alpha))
|
||||
|
||||
min_idx = np.argmin(rel_diff).item()
|
||||
if rel_diff[min_idx] < min_rel_diff:
|
||||
min_rel_diff = rel_diff[min_idx]
|
||||
worst_clip = lang_clips[min_idx]
|
||||
|
||||
if np.mean(rel_diff) < min_mean:
|
||||
min_mean = np.mean(rel_diff).item()
|
||||
worst_lang = lang
|
||||
|
||||
if np.min(rel_diff) < -0.1 or np.mean(rel_diff) < -0.025:
|
||||
if verbose > 0: print(f"FAIL ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
|
||||
if verbose > 1:
|
||||
for i, c in enumerate(lang_clips):
|
||||
print(f" {c:50s} {results[lang][i]} {rel_diff[i]}")
|
||||
else:
|
||||
if verbose > 0: print(f"PASS ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
|
||||
num_passed += 1
|
||||
|
||||
print(f"{num_passed}/{len(clips)} tests passed!")
|
||||
|
||||
print(f"worst case occured at clip {worst_clip} with relative difference of {min_rel_diff}")
|
||||
print(f"worst mean relative difference was {min_mean} for test {worst_lang}")
|
||||
|
||||
np.save(os.path.join(outputdir, f'results_' + "_".join(test_options) + f"_{bitrate}.npy"), results, allow_pickle=True)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.inputdir,
|
||||
args.outputdir,
|
||||
args.bitrate,
|
||||
args.reference_opus_demo,
|
||||
args.test_opus_demo,
|
||||
args.encoder_options,
|
||||
args.test_opus_demo_options,
|
||||
args.verbose)
|
||||
@@ -0,0 +1,205 @@
|
||||
""" module for inspecting models during inference """
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
|
||||
_state = dict()
|
||||
_folder = 'endoscopy'
|
||||
|
||||
def get_gru_gates(gru, input, state):
|
||||
hidden_size = gru.hidden_size
|
||||
|
||||
direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
|
||||
recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
|
||||
|
||||
# reset gate
|
||||
start, stop = 0 * hidden_size, 1 * hidden_size
|
||||
reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# update gate
|
||||
start, stop = 1 * hidden_size, 2 * hidden_size
|
||||
update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
|
||||
|
||||
# new gate
|
||||
start, stop = 2 * hidden_size, 3 * hidden_size
|
||||
new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
|
||||
|
||||
return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
|
||||
|
||||
|
||||
def init(folder='endoscopy'):
|
||||
""" sets up output folder for endoscopy data """
|
||||
|
||||
global _folder
|
||||
_folder = folder
|
||||
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder)
|
||||
else:
|
||||
print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
|
||||
|
||||
def write_data(key, data, fs):
|
||||
""" appends data to previous data written under key """
|
||||
|
||||
global _state
|
||||
|
||||
# convert to numpy if torch.Tensor is given
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.detach().numpy()
|
||||
|
||||
if not key in _state:
|
||||
_state[key] = {
|
||||
'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
|
||||
'fs' : fs,
|
||||
'dim' : tuple(data.shape),
|
||||
'dtype' : str(data.dtype)
|
||||
}
|
||||
|
||||
with open(os.path.join(_folder, key + '.yml'), 'w') as f:
|
||||
f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
|
||||
else:
|
||||
if _state[key]['fs'] != fs:
|
||||
raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
|
||||
if _state[key]['dtype'] != str(data.dtype):
|
||||
raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
|
||||
if _state[key]['dim'] != tuple(data.shape):
|
||||
raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
|
||||
|
||||
_state[key]['fid'].write(data.tobytes())
|
||||
|
||||
def close(folder='endoscopy'):
|
||||
""" clean up """
|
||||
for key in _state.keys():
|
||||
_state[key]['fid'].close()
|
||||
|
||||
|
||||
def read_data(folder='endoscopy'):
|
||||
""" retrieves written data as numpy arrays """
|
||||
|
||||
|
||||
keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
|
||||
|
||||
return_dict = dict()
|
||||
|
||||
for key in keys:
|
||||
with open(os.path.join(folder, key + '.yml'), 'r') as f:
|
||||
value = yaml.load(f.read(), yaml.FullLoader)
|
||||
|
||||
with open(os.path.join(folder, key + '.bin'), 'rb') as f:
|
||||
data = np.frombuffer(f.read(), dtype=value['dtype'])
|
||||
|
||||
value['data'] = data.reshape((-1,) + value['dim'])
|
||||
|
||||
return_dict[key] = value
|
||||
|
||||
return return_dict
|
||||
|
||||
def get_best_reshape(shape, target_ratio=1):
|
||||
""" calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
|
||||
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return (1,)
|
||||
|
||||
num_columns = int((pixel_count / target_ratio)**.5)
|
||||
|
||||
while (pixel_count % num_columns):
|
||||
num_columns -= 1
|
||||
|
||||
num_rows = pixel_count // num_columns
|
||||
|
||||
return (num_rows, num_columns)
|
||||
|
||||
def get_type_and_shape(shape):
|
||||
|
||||
# can happen if data is one dimensional
|
||||
if len(shape) == 0:
|
||||
shape = (1,)
|
||||
|
||||
# calculate pixel count
|
||||
if len(shape) > 1:
|
||||
pixel_count = 1
|
||||
for s in shape:
|
||||
pixel_count *= s
|
||||
else:
|
||||
pixel_count = shape[0]
|
||||
|
||||
if pixel_count == 1:
|
||||
return 'plot', (1, )
|
||||
|
||||
# stay with shape if already 2-dimensional
|
||||
if len(shape) == 2:
|
||||
if (shape[0] != pixel_count) or (shape[1] != pixel_count):
|
||||
return 'image', shape
|
||||
|
||||
return 'image', get_best_reshape(shape)
|
||||
|
||||
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
|
||||
|
||||
# determine plot setup
|
||||
num_keys = len(data.keys())
|
||||
|
||||
num_rows = int((num_keys * 3/4) ** .5)
|
||||
|
||||
num_cols = (num_keys + num_rows - 1) // num_rows
|
||||
|
||||
fig, axs = plt.subplots(num_rows, num_cols)
|
||||
fig.set_size_inches(num_cols * 5, num_rows * 5)
|
||||
|
||||
display = dict()
|
||||
|
||||
fs_max = max([val['fs'] for val in data.values()])
|
||||
|
||||
num_samples = max([val['data'].shape[0] for val in data.values()])
|
||||
|
||||
keys = sorted(data.keys())
|
||||
|
||||
# inspect data
|
||||
for i, key in enumerate(keys):
|
||||
axs[i // num_cols, i % num_cols].title.set_text(key)
|
||||
|
||||
display[key] = dict()
|
||||
|
||||
display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
|
||||
display[key]['down_factor'] = data[key]['fs'] / fs_max
|
||||
|
||||
start_index = max(start_index, half_signal_window_length)
|
||||
while stop_index < 0:
|
||||
stop_index += num_samples
|
||||
|
||||
stop_index = min(stop_index, num_samples - half_signal_window_length)
|
||||
|
||||
# actual plotting
|
||||
frames = []
|
||||
for index in range(start_index, stop_index):
|
||||
ims = []
|
||||
for i, key in enumerate(keys):
|
||||
feature_index = int(round(index * display[key]['down_factor']))
|
||||
|
||||
if display[key]['type'] == 'plot':
|
||||
ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
|
||||
|
||||
elif display[key]['type'] == 'image':
|
||||
ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
|
||||
|
||||
frames.append(ims)
|
||||
|
||||
ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
|
||||
|
||||
if not filename.endswith('.mp4'):
|
||||
filename += '.mp4'
|
||||
|
||||
ani.save(filename)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,25 @@
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation
|
||||
|
||||
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
|
||||
fig, axs = plt.subplots()
|
||||
axs.set_axis_off()
|
||||
fig.set_size_inches((duration_ms / 1000 * 5, 5))
|
||||
frames = []
|
||||
frame_duration=20
|
||||
num_frames = int(duration_ms / frame_duration + .99)
|
||||
|
||||
spec_height, spec_width = spec.shape
|
||||
for i in range(num_frames):
|
||||
xpos = (i - 1) / (num_frames - 3) * (spec_width - 1)
|
||||
new_frame = axs.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
|
||||
if i in {0, num_frames - 1}:
|
||||
frames.append([new_frame])
|
||||
else:
|
||||
line = axs.plot([xpos, xpos], [0, spec_height-1], color='white', alpha=0.8)[0]
|
||||
frames.append([new_frame, line])
|
||||
|
||||
|
||||
ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)
|
||||
ani.save(savepath, dpi=720)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user