Add LACE/NoLACE enhancement evaluation tooling: CommonVoice clip selection, LibriSpeech test-data generation, metric computation (PESQ/MOC/LACE-loss), and plotting/table scripts

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,123 @@
import argparse
import os
import yaml
import subprocess
import numpy as np
# CLI: positional CommonVoice root and output directory, plus sampling options.
parser = argparse.ArgumentParser()
parser.add_argument('commonvoice_base_dir')
parser.add_argument('output_dir')
# clips selected per language (rounded down to an even number: half female, half male)
parser.add_argument('--clips-per-language', required=False, type=int, default=10)
# seed for numpy's global RNG, makes clip selection reproducible
parser.add_argument('--seed', required=False, type=int, default=2024)
def select_clips(dir, num_clips=10):
    """Select a gender-balanced random set of clips from one CommonVoice language dir.

    Reads <dir>/validated.tsv, groups clips by client and gender, then picks
    num_clips // 2 clips from distinct female clients and num_clips // 2 from
    distinct male clients (one clip per client). Uses numpy's global RNG.

    Returns a list of clip paths under <dir>/clips, or None when the language
    does not contain enough data.
    """
    # NOTE(review): parameter `dir` shadows the builtin; kept for interface compatibility
    if num_clips % 2:
        print(f"warning: number of clips will be reduced to {num_clips - 1}")
    female = dict()
    male = dict()
    # TSV layout (CommonVoice): column 0 = client_id, column 1 = clip file name,
    # column 8 = gender -- TODO confirm against the dataset version in use
    clips = np.genfromtxt(os.path.join(dir, 'validated.tsv'), delimiter='\t', dtype=str, invalid_raise=False)
    if len(clips.shape) < 2 or len(clips) < num_clips:
        # not enough data to proceed
        return None
    # skip the header row when collecting client ids
    for client in set(clips[1:, 0]):
        client_clips = clips[clips[:, 0] == client]
        f, m = False, False
        if 'female_feminine' in client_clips[:, 8]:
            female[client] = client_clips[client_clips[:, 8] == 'female_feminine']
            f = True
        if 'male_masculine' in client_clips[:, 8]:
            male[client] = client_clips[client_clips[:, 8] == 'male_masculine']
            m = True
        if f and m:
            print(f"both male and female clips under client {client}")
    if min(len(female), len(male)) < num_clips // 2:
        return None
    # select num_clips // 2 random female clients, one random clip each
    female_client_selection = np.array(list(female.keys()), dtype=str)[np.random.choice(len(female), num_clips//2, replace=False)]
    female_clip_selection = []
    for c in female_client_selection:
        s_idx = np.random.randint(0, len(female[c]))
        female_clip_selection.append(os.path.join(dir, 'clips', female[c][s_idx, 1].item()))
    # select num_clips // 2 random male clients, one random clip each
    male_client_selection = np.array(list(male.keys()), dtype=str)[np.random.choice(len(male), num_clips//2, replace=False)]
    male_clip_selection = []
    for c in male_client_selection:
        s_idx = np.random.randint(0, len(male[c]))
        male_clip_selection.append(os.path.join(dir, 'clips', male[c][s_idx, 1].item()))
    return female_clip_selection + male_clip_selection
def ffmpeg_available():
    """Return True if an ``ffmpeg`` executable can be launched and run, False otherwise."""
    try:
        x = subprocess.run(['ffmpeg', '-h'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return x.returncode == 0
    except OSError:
        # binary missing / not executable; the original bare ``except`` would
        # also have swallowed KeyboardInterrupt and SystemExit
        return False
def convert_clips(selection, outdir):
    """Convert the selected clips to 16 kHz wav files under <outdir>/clips.

    selection maps language -> list of source clip paths. Returns a dict
    mapping language -> list of relative wav paths. Raises RuntimeError if
    ffmpeg is unavailable or a conversion fails.
    """
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")
    clip_root = os.path.join(outdir, 'clips')
    os.makedirs(clip_root, exist_ok=True)
    converted = dict()
    for lang, paths in selection.items():
        targets = []
        for src in paths:
            # strip directory and extension, convert to .wav under clips/
            stem = os.path.splitext(os.path.split(src)[-1])[0]
            target_name = os.path.join('clips', stem + '.wav')
            call_args = ['ffmpeg', '-i', src, '-ar', '16000', os.path.join(outdir, target_name)]
            print(call_args)
            result = subprocess.run(call_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            if result.returncode != 0:
                raise RuntimeError(f'could not execute {call_args}')
            targets.append(target_name)
        converted[lang] = targets
    return converted
if __name__ == "__main__":
    # fail fast before any filesystem work if ffmpeg is missing
    if not ffmpeg_available():
        raise RuntimeError("ffmpeg not available")
    args = parser.parse_args()
    base_dir = args.commonvoice_base_dir
    output_dir = args.output_dir
    # seed numpy's global RNG so clip selection is reproducible
    np.random.seed(args.seed)
    # every sub-directory of the CommonVoice root is treated as a language
    langs = os.listdir(base_dir)
    selection = dict()
    for lang in langs:
        print(f"processing {lang}...")
        # BUGFIX: --clips-per-language was parsed but never used; forward it here
        clips = select_clips(os.path.join(base_dir, lang), num_clips=args.clips_per_language)
        if clips is not None:
            selection[lang] = clips
    os.makedirs(output_dir, exist_ok=True)
    clips = convert_clips(selection, output_dir)
    # record the converted clips per language for downstream tooling
    with open(os.path.join(output_dir, 'clips.yml'), 'w') as f:
        yaml.dump(clips, f)

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Generate clean/noisy test pairs: every LibriSpeech wav is resampled to
# 16 kHz raw PCM and passed through the patched opus_demo at several bitrates.
INPUT="dataset/LibriSpeech"
OUTPUT="testdata"
OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"
# bitrates (bps) to generate; the higher rates are currently disabled
BITRATES=( 6000 7500 ) # 9000 12000 15000 18000 24000 32000 )
mkdir -p $OUTPUT
for fn in $(find $INPUT -name "*.wav")
do
    # item name without path or extension
    name=$(basename ${fn%*.wav})
    # resample to 16 kHz 16-bit signed raw audio
    sox $fn -r 16000 -b 16 -e signed-integer ${OUTPUT}/tmp.raw
    for br in ${BITRATES[@]}
    do
        folder=${OUTPUT}/"${name}_${br}.se"
        echo "creating ${folder}..."
        mkdir -p $folder
        cp ${OUTPUT}/tmp.raw ${folder}/clean.s16
        # opus encode+decode round trip produces the degraded signal
        (cd ${folder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)
    done
    rm -f ${OUTPUT}/tmp.raw
done

View File

@@ -0,0 +1,7 @@
#!/bin/bash
# Environment for the enhancement test pipeline: python interpreter, model
# checkpoints, the inference script and the patched opus_demo binary.
export PYTHON=/home/ubuntu/opt/miniconda3/envs/torch/bin/python
export LACE="/local/experiments/ietf_enhancement_studies/checkpoints/lace_checkpoint.pth"
export NOLACE="/local/experiments/ietf_enhancement_studies/checkpoints/nolace_checkpoint.pth"
export TESTMODEL="/local/experiments/ietf_enhancement_studies/opus/dnn/torch/osce/test_model.py"
export OPUSDEMO="/local/experiments/ietf_enhancement_studies/bin/opus_demo_patched"

View File

@@ -0,0 +1,113 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
from scipy.io import wavfile
from pesq import pesq
import numpy as np
from moc import compare
from moc2 import compare as compare2
#from warpq import compute_WAPRQ as warpq
from lace_loss_metric import compare as laceloss_compare
# CLI: folder of processed items plus the metric to compute over them
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('metric', type=str, choices=['pesq', 'moc', 'moc2', 'laceloss'], help='metric to be used for evaluation')
def get_bitrates(folder):
    """Read the whitespace-separated bitrate list from <folder>/bitrates.txt."""
    path = os.path.join(folder, 'bitrates.txt')
    with open(path) as bitrate_file:
        content = bitrate_file.read()
    return [int(token) for token in content.rstrip('\n').split()]
def get_itemlist(folder):
    """Return the item names (first column of each line) from <folder>/items.txt."""
    with open(os.path.join(folder, 'items.txt')) as item_file:
        return [line.split()[0] for line in item_file.readlines()]
def process_item(folder, item, bitrate, metric):
    """Compute `metric` for one item at one bitrate.

    Reads the clean reference and the opus / lace / nolace processed files
    from the corresponding sub-folders and returns a list of three scores
    in the order [opus, lace, nolace].
    """
    fs, x_clean = wavfile.read(os.path.join(folder, 'clean', f"{item}_{bitrate}_clean.wav"))
    fs, x_opus = wavfile.read(os.path.join(folder, 'opus', f"{item}_{bitrate}_opus.wav"))
    fs, x_lace = wavfile.read(os.path.join(folder, 'lace', f"{item}_{bitrate}_lace.wav"))
    fs, x_nolace = wavfile.read(os.path.join(folder, 'nolace', f"{item}_{bitrate}_nolace.wav"))
    # normalize to [-1, 1) floats; assumes the wavs hold 16-bit PCM -- TODO confirm
    x_clean = x_clean.astype(np.float32) / 2**15
    x_opus = x_opus.astype(np.float32) / 2**15
    x_lace = x_lace.astype(np.float32) / 2**15
    x_nolace = x_nolace.astype(np.float32) / 2**15
    # dispatch to the requested metric; each call compares clean vs. processed
    if metric == 'pesq':
        result = [pesq(fs, x_clean, x_opus), pesq(fs, x_clean, x_lace), pesq(fs, x_clean, x_nolace)]
    elif metric =='moc':
        result = [compare(x_clean, x_opus), compare(x_clean, x_lace), compare(x_clean, x_nolace)]
    elif metric =='moc2':
        result = [compare2(x_clean, x_opus), compare2(x_clean, x_lace), compare2(x_clean, x_nolace)]
    # elif metric == 'warpq':
    #     result = [warpq(x_clean, x_opus), warpq(x_clean, x_lace), warpq(x_clean, x_nolace)]
    elif metric == 'laceloss':
        result = [laceloss_compare(x_clean, x_opus), laceloss_compare(x_clean, x_lace), laceloss_compare(x_clean, x_nolace)]
    else:
        raise ValueError(f'unknown metric {metric}')
    return result
def process_bitrate(folder, items, bitrate, metric):
    """Score every item at one bitrate.

    Returns a (len(items), 3) array whose columns hold the
    [opus, lace, nolace] scores produced by process_item.
    """
    scores = np.zeros((len(items), 3))
    for row, name in enumerate(items):
        scores[row, :] = np.array(process_item(folder, name, bitrate, metric))
    return scores
if __name__ == "__main__":
    args = parser.parse_args()
    # items and bitrates are listed in text files inside the processed folder
    items = get_itemlist(args.folder)
    bitrates = get_bitrates(args.folder)
    # results[bitrate] is a (num_items, 3) array of [opus, lace, nolace] scores
    results = dict()
    for br in bitrates:
        print(f"processing bitrate {br}...")
        results[br] = process_bitrate(args.folder, items, br, args.metric)
    # np.save pickles the dict; reload with allow_pickle=True
    np.save(os.path.join(args.folder, f'results_{args.metric}.npy'), results)
    print("Done.")

View File

@@ -0,0 +1,330 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import torchaudio
def get_window(win_name, win_length, *args, **kwargs):
    """Create a torch window tensor by name.

    Parameters
    ----------
    win_name : str
        One of 'bartlett_window', 'blackman_window', 'hamming_window',
        'hann_window', 'kaiser_window'.
    win_length : int
        Window length in samples. Extra arguments are forwarded to the
        torch window factory.

    Raises
    ------
    ValueError
        If ``win_name`` is not a supported window type.
    """
    window_dict = {
        'bartlett_window' : torch.bartlett_window,
        'blackman_window' : torch.blackman_window,
        'hamming_window' : torch.hamming_window,
        'hann_window' : torch.hann_window,
        'kaiser_window' : torch.kaiser_window
    }
    if win_name not in window_dict:
        # name the offending window type (original raised a bare ValueError())
        raise ValueError(f"unknown window type {win_name!r}")
    return window_dict[win_name](win_length, *args, **kwargs)
def stft(x, fft_size, hop_size, win_length, window):
    """Compute an STFT magnitude spectrogram, clamped away from zero.

    Args:
        x (Tensor): input signal of shape (B, T).
        fft_size (int): FFT size.
        hop_size (int): hop size in samples.
        win_length (int): analysis window length.
        window (str): window type name understood by ``get_window``.

    Returns:
        Tensor: magnitude spectrogram of shape (B, fft_size // 2 + 1, #frames)
        (torch.stft puts frequency before time for one-sided transforms).
    """
    analysis_window = get_window(window, win_length).to(x.device)
    spectrum = torch.stft(x, fft_size, hop_size, win_length, analysis_window, return_complex=True)
    # clamping avoids log(0) / division-by-zero in the downstream losses
    magnitude = torch.abs(spectrum)
    return torch.clamp(magnitude, min=1e-7)
def spectral_convergence_loss(Y_true, Y_pred):
    """Batch-mean spectral convergence between two (magnitude) spectrograms."""
    reduce_dims = list(range(1, Y_pred.dim()))
    diff_norm = torch.norm(torch.abs(Y_true) - torch.abs(Y_pred), p="fro", dim=reduce_dims)
    ref_norm = torch.norm(Y_pred, p="fro", dim=reduce_dims)
    return torch.mean(diff_norm / (ref_norm + 1e-6))
def log_magnitude_loss(Y_true, Y_pred):
    """Mean absolute difference between log-magnitude spectra."""
    log_true = torch.log(torch.abs(Y_true) + 1e-15)
    log_pred = torch.log(torch.abs(Y_pred) + 1e-15)
    return torch.abs(log_true - log_pred).mean()
def spectral_xcorr_loss(Y_true, Y_pred):
    """One minus the normalized cross-correlation of two magnitude spectra."""
    mag_true = Y_true.abs()
    mag_pred = Y_pred.abs()
    reduce_dims = list(range(1, mag_pred.dim()))
    numerator = torch.sum(mag_true * mag_pred, dim=reduce_dims)
    denominator = torch.sqrt(
        torch.sum(mag_true ** 2, dim=reduce_dims) * torch.sum(mag_pred ** 2, dim=reduce_dims) + 1e-9
    )
    return 1 - torch.mean(numerator / denominator)
class MRLogMelLoss(nn.Module):
    """Multi-resolution log-mel-spectrogram loss.

    Averages a log-magnitude loss over mel spectrograms computed at several
    FFT sizes; FFT sizes below 128 use half the mel bands.
    """
    def __init__(self,
                 fft_sizes=[512, 256, 128, 64],
                 overlap=0.5,
                 fs=16000,
                 n_mels=18
                 ):
        # NOTE(review): mutable default argument `fft_sizes` is shared across
        # instantiations; safe only as long as it is never mutated
        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.fs = fs
        self.n_mels = n_mels
        super().__init__()
        self.mel_specs = []
        for fft_size in fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            n_mels = self.n_mels
            if fft_size < 128:
                # fewer frequency bins available -> use half the mel bands
                n_mels //= 2
            self.mel_specs.append(torchaudio.transforms.MelSpectrogram(fs, fft_size, hop_length=hop_size, n_mels=n_mels))
        # register the transforms as submodules so .to(device) / state_dict see them
        for i, mel_spec in enumerate(self.mel_specs):
            self.add_module(f'mel_spec_{i+1}', mel_spec)
    def forward(self, y_true, y_pred):
        """Average log-magnitude loss over all mel resolutions."""
        loss = torch.zeros(1, device=y_true.device)
        for mel_spec in self.mel_specs:
            Y_true = mel_spec(y_true)
            Y_pred = mel_spec(y_pred)
            loss = loss + log_magnitude_loss(Y_true, Y_pred)
        loss = loss / len(self.mel_specs)
        return loss
def create_weight_matrix(num_bins, bins_per_band=10):
    """Build a (num_bins, num_bins) band-averaging weight matrix.

    Row i averages roughly `bins_per_band` neighbouring bins around bin i;
    rows near the edges re-distribute the weight that would fall outside
    the matrix back onto the boundary bins.
    """
    weights = torch.zeros((num_bins, num_bins), dtype=torch.float32)
    half_lo = bins_per_band // 2
    half_hi = bins_per_band - half_lo
    for row in range(num_bins):
        lo = max(row - half_lo, 0)
        hi = min(row + half_hi, num_bins)
        weights[row, lo:hi] += 1
        if row < half_lo:
            # fold the clipped left part of the band back onto the first bins
            weights[row, :half_lo - row] += 1
        if row > num_bins - half_hi:
            # fold the clipped right part of the band back onto the last bins
            weights[row, num_bins - half_hi - row:] += 1
    return weights / bins_per_band
def weighted_spectral_convergence(Y_true, Y_pred, w):
    """Spectral convergence weighted by a spectral-flatness-derived mask.

    `w` is a band-averaging matrix from create_weight_matrix; Y_* are
    magnitude spectrograms with the frequency axis at dim 1 -- presumably
    matching stft()'s output layout; TODO confirm.
    """
    # calculate sfm based weights
    logY = torch.log(torch.abs(Y_true) + 1e-9)
    Y = torch.abs(Y_true)
    # band-averaged log-magnitude and magnitude (geometric vs arithmetic mean)
    avg_logY = torch.matmul(logY.transpose(1, 2), w)
    avg_Y = torch.matmul(Y.transpose(1, 2), w)
    # spectral flatness measure: ratio in (0, 1], close to 1 for noise-like bands
    sfm = torch.exp(avg_logY) / (avg_Y + 1e-9)
    # weight approaches 1 for tonal bands, 0 for flat (noise-like) bands
    weight = (torch.relu(1 - sfm) ** .5).transpose(1, 2)
    loss = torch.mean(
        torch.mean(weight * torch.abs(torch.abs(Y_true) - torch.abs(Y_pred)), dim=[1, 2])
        / (torch.mean( weight * torch.abs(Y_true), dim=[1, 2]) + 1e-9)
    )
    return loss
def gen_filterbank(N, Fs=16000):
    """Build an (N, N+1) ERB-based smoothing filterbank as a torch tensor.

    Each row is normalized to sum to one, so multiplying a magnitude
    spectrum by this matrix averages it over perceptually motivated
    bandwidths.
    """
    freq_in = (np.arange(N + 1, dtype='float32') / N * Fs / 2)[None, :]
    freq_out = (np.arange(N, dtype='float32') / N * Fs / 2)[:, None]
    # ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
    erb = 24.7 + .108 * freq_in
    delta = np.abs(freq_in - freq_out) / erb
    near = (delta < .5).astype('float32')
    # piecewise response in dB: quadratic roll-off near the center, linear beyond
    response_db = -12 * near * delta ** 2 + (1 - near) * (3 - 12 * delta)
    response = 10. ** (response_db / 10.)
    # normalize every output band to unit weight
    response = response / np.sum(response, axis=1)[:, np.newaxis]
    return torch.from_numpy(response)
def smooth_log_mag(Y_true, Y_pred, filterbank):
    """Mean absolute log-distance between filterbank-smoothed magnitude spectra."""
    smooth_true = torch.matmul(filterbank, torch.abs(Y_true))
    smooth_pred = torch.matmul(filterbank, torch.abs(Y_pred))
    log_diff = torch.log(smooth_true + 1e-9) - torch.log(smooth_pred + 1e-9)
    return torch.abs(log_diff).mean()
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss combining several weighted spectral terms.

    For every FFT size the loss can mix log-magnitude, spectral convergence,
    SFM-weighted spectral convergence, filterbank-smoothed log-magnitude and
    spectral cross-correlation terms; a weight of 0 disables that term.
    """
    def __init__(self,
                 fft_sizes=[2048, 1024, 512, 256, 128, 64],
                 overlap=0.5,
                 window='hann_window',
                 fs=16000,
                 log_mag_weight=0,
                 sc_weight=0,
                 wsc_weight=0,
                 smooth_log_mag_weight=2,
                 sxcorr_weight=1):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.overlap = overlap
        self.window = window
        self.log_mag_weight = log_mag_weight
        self.sc_weight = sc_weight
        self.wsc_weight = wsc_weight
        self.smooth_log_mag_weight = smooth_log_mag_weight
        self.sxcorr_weight = sxcorr_weight
        self.fs = fs
        # weights for SFM weighted spectral convergence loss
        # (stored as non-trainable Parameters so they follow .to(device))
        self.wsc_weights = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            # band width of roughly 1 kHz expressed in bins, capped at 11 and made even
            width = min(11, int(1000 * fft_size / self.fs + .5))
            width += width % 2
            self.wsc_weights[str(fft_size)] = torch.nn.Parameter(
                create_weight_matrix(fft_size // 2 + 1, width),
                requires_grad=False
            )
        # filterbanks for smooth log magnitude loss
        self.filterbanks = torch.nn.ParameterDict()
        for fft_size in fft_sizes:
            self.filterbanks[str(fft_size)] = torch.nn.Parameter(
                gen_filterbank(fft_size//2),
                requires_grad=False
            )
    def __call__(self, y_true, y_pred):
        # NOTE(review): overriding __call__ instead of forward() bypasses
        # nn.Module hooks -- presumably intentional; confirm if hooks are needed
        lm_loss = torch.zeros(1, device=y_true.device)
        sc_loss = torch.zeros(1, device=y_true.device)
        wsc_loss = torch.zeros(1, device=y_true.device)
        slm_loss = torch.zeros(1, device=y_true.device)
        sxcorr_loss = torch.zeros(1, device=y_true.device)
        for fft_size in self.fft_sizes:
            hop_size = int(round(fft_size * (1 - self.overlap)))
            win_size = fft_size
            Y_true = stft(y_true, fft_size, hop_size, win_size, self.window)
            Y_pred = stft(y_pred, fft_size, hop_size, win_size, self.window)
            # accumulate only the terms whose weight is non-zero
            if self.log_mag_weight > 0:
                lm_loss = lm_loss + log_magnitude_loss(Y_true, Y_pred)
            if self.sc_weight > 0:
                sc_loss = sc_loss + spectral_convergence_loss(Y_true, Y_pred)
            if self.wsc_weight > 0:
                wsc_loss = wsc_loss + weighted_spectral_convergence(Y_true, Y_pred, self.wsc_weights[str(fft_size)])
            if self.smooth_log_mag_weight > 0:
                slm_loss = slm_loss + smooth_log_mag(Y_true, Y_pred, self.filterbanks[str(fft_size)])
            if self.sxcorr_weight > 0:
                sxcorr_loss = sxcorr_loss + spectral_xcorr_loss(Y_true, Y_pred)
        # weighted sum of the terms, averaged over the number of resolutions
        total_loss = (self.log_mag_weight * lm_loss + self.sc_weight * sc_loss
                + self.wsc_weight * wsc_loss + self.smooth_log_mag_weight * slm_loss
                + self.sxcorr_weight * sxcorr_loss) / len(self.fft_sizes)
        return total_loss
def td_l2_norm(y_true, y_pred):
    """Per-item time-domain MSE normalized by the RMS energy of y_pred, batch-averaged."""
    reduce_dims = list(range(1, y_true.dim()))
    mse = torch.mean((y_true - y_pred) ** 2, dim=reduce_dims)
    rms = torch.mean(y_pred ** 2, dim=reduce_dims) ** .5
    return torch.mean(mse / (rms + 1e-6))
class LaceLoss(nn.Module):
    """Combined spectral + time-domain loss used as the LACE distance metric.

    Wraps an MRSTFTLoss (smoothed log-magnitude + spectral cross-correlation
    terms only) plus a time-domain relative L2 term; `compare` exposes it as
    a numpy-in / float-out metric like the other comparison functions.
    """
    def __init__(self):
        super().__init__()
        self.stftloss = MRSTFTLoss(log_mag_weight=0, sc_weight=0, wsc_weight=0, smooth_log_mag_weight=2, sxcorr_weight=1)
    def forward(self, x, y):
        specloss = self.stftloss(x, y)
        phaseloss = td_l2_norm(x, y)
        # NOTE(review): the /13 normalization constant is not derived here; confirm intent
        total_loss = (specloss + 10 * phaseloss) / 13
        return total_loss
    def compare(self, x_ref, x_deg):
        """Distance between reference and degraded 1-D float numpy signals."""
        # trim items to same size
        n = min(len(x_ref), len(x_deg))
        x_ref = x_ref[:n].copy()
        x_deg = x_deg[:n].copy()
        # pre-emphasis
        x_ref[1:] -= 0.85 * x_ref[:-1]
        x_deg[1:] -= 0.85 * x_deg[:-1]
        # evaluate on the same device the module's parameters live on
        device = next(iter(self.parameters())).device
        x = torch.from_numpy(x_ref).to(device)
        y = torch.from_numpy(x_deg).to(device)
        with torch.no_grad():
            dist = 10 * self.forward(x, y)
        return dist.cpu().numpy().item()
# module-level singleton so repeated compare() calls reuse one model instance
lace_loss = LaceLoss()
# run on GPU when available; LaceLoss.compare moves inputs to the same device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lace_loss.to(device)
def compare(x, y):
    """Module-level wrapper matching the compare(ref, deg) API of the other metrics."""
    return lace_loss.compare(x, y)

View File

@@ -0,0 +1,116 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: folder containing results_<metric>.npy files, metric selector and an
# optional alternative output folder
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
# BUGFIX: 'moc2' added -- load_data reads results_moc2.npy but the choice was missing
parser.add_argument('--metric', choices=['pesq', 'moc', 'moc2', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load all pre-computed per-metric result files found in `folder`.

    Looks for files named results_<metric>.npy and returns a dict mapping
    metric name -> {bitrate: per-item score array}. Missing files are skipped.
    """
    data = dict()
    # one pickled dict per metric; keep this list in sync with the evaluators
    for metric in ['moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss']:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def plot_data(filename, data, title=None):
    """Render grouped boxplots (Opus/LACE/NoLACE per bitrate) to `filename`.

    `data` maps bitrate -> array whose columns are [opus, lace, nolace]
    per-item scores.
    """
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'NoLACE {br/1000:.1f} kb/s'] = data[br][:, 2]
    # NOTE(review): usetex=True requires a working LaTeX installation
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })
    black = '#000000'
    red = '#ff5745'
    blue = '#007dbc'
    colors = [black, red, blue]
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='LACE'),
                       Patch(facecolor=colors[2], label='NoLACE')]
    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
    # color boxes cyclically: three systems per bitrate group
    for i, patch in enumerate(bplot['boxes']):
        patch.set_facecolor(colors[i%3])
    ax.set_xticklabels(compare_dict.keys(), rotation=290)
    if title is not None:
        ax.set_title(title)
    ax.legend(handles=legend_elements)
    fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)
    # with --metric all, plot every metric that was found on disk
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)
    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
    print("Done.")

View File

@@ -0,0 +1,109 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: folder containing results_<metric>.npy files, metric selector and an
# optional alternative output folder
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load all pre-computed per-metric result files found in `folder`.

    Looks for files named results_<metric>.npy and returns a dict mapping
    metric name -> {bitrate: per-item score array}. Missing files are skipped.
    """
    data = dict()
    # one pickled dict per metric; keep this list in sync with the evaluators
    for metric in ['moc', 'pesq', 'warpq', 'nomad', 'laceloss']:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def plot_data(filename, data, title=None):
    """Render grouped boxplots comparing Opus vs the two LACE loss variants.

    `data` maps bitrate -> array whose columns are
    [opus, lace(MOC only), lace(MOC + TD)] per-item scores.
    """
    compare_dict = dict()
    for br in data.keys():
        compare_dict[f'Opus {br/1000:.1f} kb/s'] = data[br][:, 0]
        compare_dict[f'LACE (MOC only) {br/1000:.1f} kb/s'] = data[br][:, 1]
        compare_dict[f'LACE (MOC + TD) {br/1000:.1f} kb/s'] = data[br][:, 2]
    # NOTE(review): usetex=True requires a working LaTeX installation
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 32
    })
    colors = ['pink', 'lightblue', 'lightgreen']
    legend_elements = [Patch(facecolor=colors[0], label='Opus SILK'),
                       Patch(facecolor=colors[1], label='MOC loss only'),
                       Patch(facecolor=colors[2], label='MOC + TD loss')]
    fig, ax = plt.subplots()
    fig.set_size_inches(40, 20)
    bplot = ax.boxplot(compare_dict.values(), showfliers=False, notch=True, patch_artist=True)
    # color boxes cyclically: three systems per bitrate group
    for i, patch in enumerate(bplot['boxes']):
        patch.set_facecolor(colors[i%3])
    ax.set_xticklabels(compare_dict.keys(), rotation=290)
    if title is not None:
        ax.set_title(title)
    ax.legend(handles=legend_elements)
    fig.savefig(filename, bbox_inches='tight')
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)
    # with --metric all, plot every metric that was found on disk
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)
    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        plot_data(os.path.join(folder, f"boxplot_{metric}.png"), data[metric], title=metric.upper())
    print("Done.")

View File

@@ -0,0 +1,124 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: folder containing results_<metric>.npy files, metric selector and an
# optional alternative output folder
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
# BUGFIX: 'moc2' added -- load_data reads results_moc2.npy but the choice was missing
parser.add_argument('--metric', choices=['pesq', 'moc', 'moc2', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load all pre-computed per-metric result files found in `folder`.

    Looks for files named results_<metric>.npy and returns a dict mapping
    metric name -> {bitrate: per-item score array}. Missing files are skipped.
    """
    data = dict()
    # one pickled dict per metric; keep this list in sync with the evaluators
    for metric in ['moc', 'moc2', 'pesq', 'warpq', 'nomad', 'laceloss']:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def make_table(filename, data, title=None):
    """Write per-bitrate "mean (std)" score tables as txt/html/csv and print them.

    `data` maps bitrate -> array whose columns are [opus, lace, nolace] scores.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
    for br in data.keys():
        columns = [data[br][:, col] for col in range(3)]
        tbl.add_row([br] + [f"{float(c.mean()):.3f} ({float(c.std()):.2f})" for c in columns])
    # emit the same table in the three formats
    for ext, rendered in ((".txt", str(tbl)), (".html", tbl.get_html_string()), (".csv", tbl.get_csv_string())):
        with open(filename + ext, "w") as f:
            f.write(rendered)
    print(tbl)
def make_diff_table(filename, data, title=None):
    """Write per-bitrate "mean (std)" improvement-over-Opus tables as txt/html/csv.

    `data` maps bitrate -> array whose columns are [opus, lace, nolace] scores;
    the table reports lace-opus and nolace-opus differences.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
    for br in data.keys():
        baseline = data[br][:, 0]
        deltas = [data[br][:, col] - baseline for col in (1, 2)]
        tbl.add_row([br] + [f"{float(d.mean()):.3f} ({float(d.std()):.2f})" for d in deltas])
    # emit the same table in the three formats
    for ext, rendered in ((".txt", str(tbl)), (".html", tbl.get_html_string()), (".csv", tbl.get_csv_string())):
        with open(filename + ext, "w") as f:
            f.write(rendered)
    print(tbl)
if __name__ == "__main__":
    args = parser.parse_args()
    data = load_data(args.folder)
    # with --metric all, tabulate every metric that was found on disk
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)
    for metric in metrics:
        # BUGFIX: message said "Plotting data" (copy-paste from the plot
        # script) although this script generates tables
        print(f"Generating tables for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
    print("Done.")

View File

@@ -0,0 +1,121 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from matplotlib.patches import Patch
# CLI: folder containing results_<metric>.npy files, metric selector and an
# optional alternative output folder
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='path to folder with pre-calculated metrics')
parser.add_argument('--metric', choices=['pesq', 'moc', 'warpq', 'nomad', 'laceloss', 'all'], default='all', help='default: all')
parser.add_argument('--output', type=str, default=None, help='alternative output folder, default: folder')
def load_data(folder):
    """Load pre-computed metric results from *folder*.

    Looks for files named ``results_<metric>.npy`` (each holding a pickled
    dict mapping bitrate -> per-item score array) and returns a dict that
    maps the metric name to the loaded dict. Missing metric files are
    silently skipped, so the result may be empty.
    """
    data = dict()
    # one file per metric; data-driven loop instead of five copy-pasted
    # if-blocks (behavior is identical)
    for metric in ['moc', 'pesq', 'warpq', 'nomad', 'laceloss']:
        path = os.path.join(folder, f'results_{metric}.npy')
        if os.path.isfile(path):
            data[metric] = np.load(path, allow_pickle=True).item()
    return data
def make_table(filename, data, title=None):
    """Write a table of per-condition "mean (std)" scores to txt/html/csv.

    data: dict mapping bitrate -> ndarray of shape (num_items, 3) whose
    columns are Opus, LACE and NoLACE scores.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'Opus', 'LACE', 'NoLACE']
    for bitrate, scores in data.items():
        row = [bitrate]
        for col in range(3):
            values = scores[:, col]
            row.append(f"{float(values.mean()):.3f} ({float(values.std()):.2f})")
        tbl.add_row(row)

    # render once per output format
    for ext, rendered in ((".txt", str(tbl)),
                          (".html", tbl.get_html_string()),
                          (".csv", tbl.get_csv_string())):
        with open(filename + ext, "w") as f:
            f.write(rendered)

    print(tbl)
def make_diff_table(filename, data, title=None):
    """Write a table of "mean (std)" score differences vs. Opus to txt/html/csv.

    data: dict mapping bitrate -> ndarray of shape (num_items, 3) whose
    columns are Opus, LACE and NoLACE scores; column 0 is the baseline.
    """
    tbl = PrettyTable()
    tbl.field_names = ['bitrate (bps)', 'LACE - Opus', 'NoLACE - Opus']
    for bitrate, scores in data.items():
        baseline = scores[:, 0]
        diffs = [scores[:, col] - baseline for col in (1, 2)]
        tbl.add_row([bitrate] + [f"{float(d.mean()):.3f} ({float(d.std()):.2f})" for d in diffs])

    # render once per output format
    for ext, rendered in ((".txt", str(tbl)),
                          (".html", tbl.get_html_string()),
                          (".csv", tbl.get_csv_string())):
        with open(filename + ext, "w") as f:
            f.write(rendered)

    print(tbl)
if __name__ == "__main__":
    args = parser.parse_args()
    # load whatever results_<metric>.npy files are present in the folder
    data = load_data(args.folder)
    metrics = list(data.keys()) if args.metric == 'all' else [args.metric]
    # default to writing the tables next to the input data
    folder = args.folder if args.output is None else args.output
    os.makedirs(folder, exist_ok=True)
    for metric in metrics:
        print(f"Plotting data for {metric} metric...")
        make_table(os.path.join(folder, f"table_{metric}"), data[metric])
        make_diff_table(os.path.join(folder, f"table_diff_{metric}"), data[metric])
    print("Done.")

View File

@@ -0,0 +1,182 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
frame_length = (fs + 49) // 50
x = x[: frame_length * (len(x) // frame_length)]
frames = x.reshape(-1, frame_length)
frame_energy = np.sum(frames ** 2, axis=1)
frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
max_threshold = frame_energy.max() * 10 ** (stop_db/20)
vactive = np.ones_like(frames)
vactive[frame_energy_smooth < max_threshold, :] = 0
vactive = vactive.reshape(-1)
filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
filter = filter / filter.sum()
mask = np.convolve(vactive, filter, mode='same')
return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Downsample a per-sample VAD mask to one mean value per analysis frame."""
    required = frame_size + (num_frames - 1) * hop_size
    # zero-pad or truncate so every frame window is fully defined
    if len(mask) < required:
        padding = np.zeros(required - len(mask))
        mask = np.concatenate((mask, padding), dtype=mask.dtype)
    else:
        mask = mask[:required]
    frame_means = [np.mean(mask[start : start + frame_size])
                   for start in range(0, num_frames * hop_size, hop_size)]
    return np.array(frame_means)
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Short-time power spectrum of x, one row per analysis window.

    Returns an array of shape (num_spectra, window_size // 2 + 1) holding
    squared FFT magnitudes of windowed frames spaced hop_size samples apart.
    """
    # NOTE(review): this yields up to one frame fewer than the usual
    # (len(x) - window_size) // hop_size + 1 — presumably intentional to
    # stay clear of the signal boundary; confirm.
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    window = scipy.signal.get_window(window, window_size)
    N = window_size // 2
    # stack the frames into a matrix, then apply the analysis window row-wise
    frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
    # keep only non-negative frequencies (input is real)
    psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
    return psd
def frequency_mask(num_bands, up_factor, down_factor):
    """Build a spreading matrix modelling frequency-domain masking.

    Row b of the result gives the masking contribution of every band onto
    band b: energy spreads upward with geometric decay up_factor and
    downward with decay down_factor.
    """
    upward = np.zeros((num_bands, num_bands))
    downward = np.zeros((num_bands, num_bands))
    for band in range(num_bands):
        # bands at or below `band` contribute with exponentially decaying weight
        upward[band, : band + 1] = up_factor ** np.arange(band, -1, -1)
        # bands at or above `band` contribute with a different decay
        downward[band, band:] = down_factor ** np.arange(num_bands - band)
    return downward @ upward
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank built from a list of band-edge bin indices."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        # default: just enough bins to cover the last band
        num_bins = band_limits[-1]
    fb = np.zeros((num_bands, num_bins))
    for band, (lo, hi) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        fb[band, lo:hi] = 1
    return fb
def compare(x, y, apply_vad=False):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only voice-active frames contribute to
            the final error

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=81)

    # trim samples to same size and rescale to 16-bit amplitude range
    num_samples = min(len(x), len(y))
    x = x[:num_samples] * 2**15
    y = y[:num_samples] * 2**15

    # power spectra with a small floor to avoid log/division issues
    psd_x = power_spectrum(x) + 100000
    psd_y = power_spectrum(y) + 100000
    num_frames = psd_x.shape[0]

    # average band energies of the reference
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking: spread band energies across neighbouring bands
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: leaky integration over frames
    for i in range(1, num_frames):
        mask_x[i, :] += 0.5 * mask_x[i-1, :]

    # the reference-derived mask is applied to both signals
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio, averaged per band/frame
    re = masked_psd_y / masked_psd_x
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    Ef = np.mean(Eb , axis=1)

    if apply_vad:
        # NOTE(review): the VAD runs on the rescaled reference; its
        # threshold is relative, so the scaling should not matter — confirm.
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic-power pooling over (active) frames
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)
    return float(err)
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # bug fix: the old check `max(fs1, fs2) != 16000` accepted a pair where
    # one file was *below* 16 kHz; both rates must be exactly 16 kHz
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency diffrent from 16kHz')

    # rescale 16-bit PCM to [-1, 1]
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)
    print(f"MOC: {err}")

View File

@@ -0,0 +1,190 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import numpy as np
import scipy.signal
def compute_vad_mask(x, fs, stop_db=-70):
    """Energy-based voice activity detection.

    Returns the signal trimmed to a whole number of 20 ms frames and a
    per-sample activity mask in [0, 1]; frames whose smoothed energy falls
    more than |stop_db| dB below the loudest frame are marked inactive.
    """
    frame_length = (fs + 49) // 50  # 20 ms, rounded up
    x = x[: frame_length * (len(x) // frame_length)]
    frames = x.reshape(-1, frame_length)
    frame_energy = np.sum(frames ** 2, axis=1)
    # 5-frame moving average avoids dropping short energy dips
    frame_energy_smooth = np.convolve(frame_energy, np.ones(5) / 5, mode='same')
    max_threshold = frame_energy.max() * 10 ** (stop_db/20)
    vactive = np.ones_like(frames)
    vactive[frame_energy_smooth < max_threshold, :] = 0
    vactive = vactive.reshape(-1)
    # normalized sine window smooths the hard per-frame decisions
    filter = np.sin(np.arange(frame_length) * np.pi / (frame_length - 1))
    filter = filter / filter.sum()
    mask = np.convolve(vactive, filter, mode='same')
    return x, mask
def convert_mask(mask, num_frames, frame_size=160, hop_size=40):
    """Downsample a per-sample VAD mask to one mean value per analysis frame."""
    num_samples = frame_size + (num_frames - 1) * hop_size
    # zero-pad or truncate so every frame window is fully defined
    if len(mask) < num_samples:
        mask = np.concatenate((mask, np.zeros(num_samples - len(mask))), dtype=mask.dtype)
    else:
        mask = mask[:num_samples]
    new_mask = np.array([np.mean(mask[i*hop_size : i*hop_size + frame_size]) for i in range(num_frames)])
    return new_mask
def power_spectrum(x, window_size=160, hop_size=40, window='hamming'):
    """Short-time power spectrum of x, one row per analysis window.

    Returns an array of shape (num_spectra, window_size // 2 + 1) holding
    squared FFT magnitudes of windowed frames spaced hop_size samples apart.
    """
    # NOTE(review): this yields up to one frame fewer than the usual
    # (len(x) - window_size) // hop_size + 1 — presumably intentional; confirm.
    num_spectra = (len(x) - window_size - hop_size) // hop_size
    window = scipy.signal.get_window(window, window_size)
    N = window_size // 2
    # stack the frames into a matrix, then apply the analysis window row-wise
    frames = np.concatenate([x[np.newaxis, i * hop_size : i * hop_size + window_size] for i in range(num_spectra)]) * window
    # keep only non-negative frequencies (input is real)
    psd = np.abs(np.fft.fft(frames, axis=1)[:, :N + 1]) ** 2
    return psd
def frequency_mask(num_bands, up_factor, down_factor):
    """Spreading matrix modelling frequency-domain masking: energy spreads
    upward with geometric decay up_factor and downward with down_factor."""
    up_mask = np.zeros((num_bands, num_bands))
    down_mask = np.zeros((num_bands, num_bands))
    for i in range(num_bands):
        # bands at or below i contribute with exponentially decaying weight
        up_mask[i, : i + 1] = up_factor ** np.arange(i, -1, -1)
        # bands at or above i contribute with a different decay
        down_mask[i, i :] = down_factor ** np.arange(num_bands - i)
    return down_mask @ up_mask
def rect_fb(band_limits, num_bins=None):
    """Rectangular (0/1) filter bank built from a list of band-edge bin indices."""
    num_bands = len(band_limits) - 1
    if num_bins is None:
        # default: just enough bins to cover the last band
        num_bins = band_limits[-1]
    fb = np.zeros((num_bands, num_bins))
    for i in range(num_bands):
        fb[i, band_limits[i]:band_limits[i+1]] = 1
    return fb
def _compare(x, y, apply_vad=False, factor=1):
    """ Modified version of opus_compare for 16 kHz mono signals

    Args:
        x (np.ndarray): reference input signal scaled to [-1, 1]
        y (np.ndarray): test signal scaled to [-1, 1]
        apply_vad (bool): if True, only voice-active frames contribute to
            the final error
        factor (int): time/frequency resolution multiplier; 1 corresponds
            to a 10 ms window with 2.5 ms hop at 16 kHz

    Returns:
        float: perceptually weighted error
    """
    # filter bank: bark scale with minimum-2-bin bands and cutoff at 7.5 kHz
    band_limits = [factor * b for b in [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]]
    window_size = factor * 160
    hop_size = factor * 40
    num_bins = window_size // 2 + 1
    num_bands = len(band_limits) - 1
    fb = rect_fb(band_limits, num_bins=num_bins)

    # trim samples to same size and rescale to 16-bit amplitude range
    num_samples = min(len(x), len(y))
    x = x[:num_samples].copy() * 2**15
    y = y[:num_samples].copy() * 2**15

    # power spectra with a small floor to avoid log/division issues
    psd_x = power_spectrum(x, window_size=window_size, hop_size=hop_size) + 100000
    psd_y = power_spectrum(y, window_size=window_size, hop_size=hop_size) + 100000
    num_frames = psd_x.shape[0]

    # average band energies of the reference
    be_x = (psd_x @ fb.T) / np.sum(fb, axis=1)

    # frequency masking: spread band energies across neighbouring bands
    f_mask = frequency_mask(num_bands, 0.1, 0.03)
    mask_x = be_x @ f_mask.T

    # temporal masking: leaky integration, decay adjusted for the hop size
    for i in range(1, num_frames):
        mask_x[i, :] += (0.5 ** factor) * mask_x[i-1, :]

    # the reference-derived mask is applied to both signals
    masked_psd_x = psd_x + 0.1 * (mask_x @ fb)
    masked_psd_y = psd_y + 0.1 * (mask_x @ fb)

    # 2-frame average
    masked_psd_x = masked_psd_x[1:] + masked_psd_x[:-1]
    masked_psd_y = masked_psd_y[1:] + masked_psd_y[:-1]

    # distortion metric: squared log spectral ratio per band
    re = masked_psd_y / masked_psd_x
    #im = re - np.log(re) - 1  # alternative (Itakura-Saito-like) distortion, kept for reference
    im = np.log(re) ** 2
    Eb = ((im @ fb.T) / np.sum(fb, axis=1))
    # NOTE(review): the `** 1` is a no-op — presumably a leftover from
    # experimenting with different band-pooling exponents
    Ef = np.mean(Eb ** 1, axis=1)

    if apply_vad:
        _, mask = compute_vad_mask(x, 16000)
        mask = convert_mask(mask, Ef.shape[0])
    else:
        mask = np.ones_like(Ef)

    # cubic-power pooling over (active) frames
    err = np.mean(np.abs(Ef[mask > 1e-6]) ** 3) ** (1/6)

    return float(err)
def compare(x, y, apply_vad=False):
    """Single-scale MOC distance between x and y.

    The L2 norm over a one-element list is currently a no-op; the structure
    leaves room for combining several resolution factors later.
    """
    err = np.linalg.norm([_compare(x, y, apply_vad=apply_vad, factor=1)], ord=2)
    return err
if __name__ == "__main__":
    import argparse
    from scipy.io import wavfile

    parser = argparse.ArgumentParser()
    parser.add_argument('ref', type=str, help='reference wav file')
    parser.add_argument('deg', type=str, help='degraded wav file')
    parser.add_argument('--apply-vad', action='store_true')
    args = parser.parse_args()

    fs1, x = wavfile.read(args.ref)
    fs2, y = wavfile.read(args.deg)

    # bug fix: the old check `max(fs1, fs2) != 16000` accepted a pair where
    # one file was *below* 16 kHz; both rates must be exactly 16 kHz
    if fs1 != 16000 or fs2 != 16000:
        raise ValueError('error: encountered sampling frequency diffrent from 16kHz')

    # rescale 16-bit PCM to [-1, 1]
    x = x.astype(np.float32) / 2**15
    y = y.astype(np.float32) / 2**15

    err = compare(x, y, apply_vad=args.apply_vad)
    print(f"MOC: {err}")

View File

@@ -0,0 +1,98 @@
#!/bin/bash

# process_dataset.sh <input folder> <output folder>
#
# Encodes every wav file under the input folder with opus_demo at a range of
# bitrates, then enhances the decoded audio with the LACE and NoLACE models.
# Requires the environment variables PYTHON, TESTMODEL, OPUSDEMO, LACE and
# NOLACE (validated below); BITRATES may override the default bitrate list.

if [ ! -f "$PYTHON" ]
then
    echo "PYTHON variable does not link to a file. Please point it to your python executable."
    exit 1
fi

if [ ! -f "$TESTMODEL" ]
then
    echo "TESTMODEL variable does not link to a file. Please point it to your copy of test_model.py"
    exit 1
fi

if [ ! -f "$OPUSDEMO" ]
then
    echo "OPUSDEMO variable does not link to a file. Please point it to your patched version of opus_demo."
    exit 1
fi

if [ ! -f "$LACE" ]
then
    echo "LACE variable does not link to a file. Please point it to your copy of the LACE checkpoint."
    exit 1
fi

if [ ! -f "$NOLACE" ]
then
    # bug fix: this message used to say "LACE variable"
    echo "NOLACE variable does not link to a file. Please point it to your copy of the NOLACE checkpoint."
    exit 1
fi

case $# in
    2) INPUT=$1; OUTPUT=$2;;
    *) echo "process_dataset.sh <input folder> <output folder>"; exit 1;;
esac

# refuse to overwrite an existing run (quoted to survive paths with spaces)
if [ -d "$OUTPUT" ]
then
    echo "output folder $OUTPUT exists, aborting..."
    exit 1
fi

mkdir -p "$OUTPUT"

if [ "$BITRATES" == "" ]
then
    BITRATES=( 6000 7500 9000 12000 15000 18000 24000 32000 )
    echo "BITRATES variable not defined. Proceeding with default bitrates ${BITRATES[@]}."
fi

# record which model checkpoints were used for this run
echo "LACE=${LACE}" > ${OUTPUT}/info.txt
echo "NOLACE=${NOLACE}" >> ${OUTPUT}/info.txt

ITEMFILE=${OUTPUT}/items.txt
BITRATEFILE=${OUTPUT}/bitrates.txt

FPROCESSING=${OUTPUT}/processing
FCLEAN=${OUTPUT}/clean
FOPUS=${OUTPUT}/opus
FLACE=${OUTPUT}/lace
FNOLACE=${OUTPUT}/nolace

mkdir -p $FPROCESSING $FCLEAN $FOPUS $FLACE $FNOLACE

echo "${BITRATES[@]}" > $BITRATEFILE

for fn in $(find $INPUT -type f -name "*.wav")
do
    # anonymize items via uuid; the mapping is kept in items.txt
    UUID=$(uuid)
    echo "$UUID $fn" >> $ITEMFILE
    PIDS=( )
    for br in ${BITRATES[@]}
    do
        # run opus (synchronously; the enhancement jobs below run in parallel)
        pfolder=${FPROCESSING}/${UUID}_${br}
        mkdir -p $pfolder
        sox $fn -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16
        (cd ${pfolder} && $OPUSDEMO voip 16000 1 $br clean.s16 noisy.s16)

        # copy clean and opus
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/clean.s16 $FCLEAN/${UUID}_${br}_clean.wav
        sox -c 1 -r 16000 -b 16 -e signed-integer $pfolder/noisy.s16 $FOPUS/${UUID}_${br}_opus.wav

        # run LACE in the background; pid collected for the wait below
        $PYTHON $TESTMODEL $pfolder $LACE $FLACE/${UUID}_${br}_lace.wav &
        PIDS+=( "$!" )

        # run NoLACE in the background; pid collected for the wait below
        $PYTHON $TESTMODEL $pfolder $NOLACE $FNOLACE/${UUID}_${br}_nolace.wav &
        PIDS+=( "$!" )
    done

    # wait for all enhancement jobs of this item before starting the next
    for pid in ${PIDS[@]}
    do
        wait $pid
    done
done

View File

@@ -0,0 +1,138 @@
"""
/* Copyright (c) 2023 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import argparse
import tempfile
import shutil
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.io import wavfile
import numpy as np
from nomad_audio.nomad import Nomad
# Command-line interface: `folder` must contain the items.txt/bitrates.txt
# metadata plus the clean/opus/lace/nolace wav subfolders.
parser = argparse.ArgumentParser()
parser.add_argument('folder', type=str, help='folder with processed items')
parser.add_argument('--full-reference', action='store_true', help='use NOMAD as full-reference metric')
parser.add_argument('--device', type=str, default=None, help='device for Nomad')
def get_bitrates(folder):
    """Read the whitespace-separated bitrate list (bps) from <folder>/bitrates.txt."""
    path = os.path.join(folder, 'bitrates.txt')
    with open(path) as f:
        content = f.read()
    return [int(token) for token in content.rstrip('\n').split()]
def get_itemlist(folder):
    """Read the item ids (first column of <folder>/items.txt)."""
    with open(os.path.join(folder, 'items.txt')) as f:
        return [line.split()[0] for line in f.readlines()]
def nomad_wrapper(ref_folder, deg_folder, full_reference=False, ref_embeddings=None, device=None):
    """Score every file in deg_folder with NOMAD.

    In no-reference mode (default) ref_folder is passed to Nomad.predict as
    the non-matching-reference set. In full-reference mode the Euclidean
    distance between reference and degraded embeddings is used instead;
    pass *ref_embeddings* to reuse embeddings computed by an earlier call.

    Returns (results, ref_embeddings) where results maps the degraded file
    stem to its score; ref_embeddings is None in no-reference mode.
    """
    model = Nomad(device=device)
    if not full_reference:
        results = model.predict(nmr=ref_folder, deg=deg_folder)[0].to_dict()['NOMAD']
        return results, None
    else:
        if ref_embeddings is None:
            print(f"Computing reference embeddings from {ref_folder}")
            # sorted() keeps reference and degraded rows aligned by filename
            ref_data = pd.DataFrame(sorted(os.listdir(ref_folder)))
            ref_data.columns = ['filename']
            ref_data['filename'] = [os.path.join(ref_folder, x) for x in ref_data['filename']]
            ref_embeddings = model.get_embeddings_csv(model.model, ref_data).set_index('filename')

        print(f"Computing degraded embeddings from {deg_folder}")
        deg_data = pd.DataFrame(sorted(os.listdir(deg_folder)))
        deg_data.columns = ['filename']
        deg_data['filename'] = [os.path.join(deg_folder, x) for x in deg_data['filename']]
        deg_embeddings = model.get_embeddings_csv(model.model, deg_data).set_index('filename')

        # only the diagonal (matching pairs) is needed; computing the full
        # distance matrix is wasteful but simple
        dist = np.diag(cdist(ref_embeddings, deg_embeddings)) # wasteful
        test_files = [x.split('/')[-1].split('.')[0] for x in deg_embeddings.index]
        results = dict(zip(test_files, dist))
        return results, ref_embeddings
def nomad_process_all(folder, full_reference=False, device=None):
    """Run NOMAD over all items, bitrates and conditions under *folder*.

    NOMAD scores whole folders, so the wav files are first staged into
    per-condition temp folders using uniform names <item>_<bitrate>.wav.
    Returns {bitrate: ndarray of shape (num_items, 3)} with columns
    opus / lace / nolace.
    """
    bitrates = get_bitrates(folder)
    items = get_itemlist(folder)
    with tempfile.TemporaryDirectory() as dir:
        cleandir = os.path.join(dir, 'clean')
        opusdir = os.path.join(dir, 'opus')
        lacedir = os.path.join(dir, 'lace')
        nolacedir = os.path.join(dir, 'nolace')

        # stage the files into one flat folder per condition
        for d in [cleandir, opusdir, lacedir, nolacedir]: os.makedirs(d)
        for br in bitrates:
            for item in items:
                for cond in ['clean', 'opus', 'lace', 'nolace']:
                    shutil.copyfile(os.path.join(folder, cond, f"{item}_{br}_{cond}.wav"), os.path.join(dir, cond, f"{item}_{br}.wav"))

        # reference embeddings are computed once and reused for all conditions
        nomad_opus, ref_embeddings = nomad_wrapper(cleandir, opusdir, full_reference=full_reference, ref_embeddings=None)
        nomad_lace, ref_embeddings = nomad_wrapper(cleandir, lacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)
        nomad_nolace, ref_embeddings = nomad_wrapper(cleandir, nolacedir, full_reference=full_reference, ref_embeddings=ref_embeddings)

        # regroup flat {item_bitrate: score} dicts into per-bitrate arrays
        results = dict()
        for br in bitrates:
            results[br] = np.zeros((len(items), 3))
            for i, item in enumerate(items):
                key = f"{item}_{br}"
                results[br][i, 0] = nomad_opus[key]
                results[br][i, 1] = nomad_lace[key]
                results[br][i, 2] = nomad_nolace[key]

        return results
if __name__ == "__main__":
    args = parser.parse_args()
    items = get_itemlist(args.folder)
    bitrates = get_bitrates(args.folder)
    results = nomad_process_all(args.folder, full_reference=args.full_reference, device=args.device)
    # results: {bitrate: (num_items, 3)} with columns opus/lace/nolace,
    # stored alongside the other results_<metric>.npy files
    np.save(os.path.join(args.folder, f'results_nomad.npy'), results)
    print("Done.")

View File

@@ -0,0 +1,196 @@
import os
import argparse
import yaml
import subprocess
import numpy as np
from moc2 import compare as moc
# development flag; not read in the code visible here
DEBUG=False

# Command line: compare a reference opus_demo decoder against a decoder
# under test on a corpus of clips, using the MOC metric.
parser = argparse.ArgumentParser()
parser.add_argument('inputdir', type=str, help='Input folder with test items')
parser.add_argument('outputdir', type=str, help='Output folder')
parser.add_argument('bitrate', type=int, help='bitrate to test')
parser.add_argument('--reference_opus_demo', type=str, default='./opus_demo', help='reference opus_demo binary for generating bitstreams and reference output')
parser.add_argument('--encoder_options', type=str, default="", help='encoder options (e.g. -complexity 5)')
parser.add_argument('--test_opus_demo', type=str, default='./opus_demo', help='opus_demo binary under test')
parser.add_argument('--test_opus_demo_options', type=str, default='-dec_complexity 7', help='options for test opus_demo (e.g. "-dec_complexity 7")')
parser.add_argument('--verbose', type=int, default=0, help='verbosity level: 0 for quiet (default), 1 for reporting individual test results, 2 for reporting per-item scores in failed tests')
def run_opus_encoder(opus_demo_path, input_pcm_path, bitstream_path, application, fs, num_channels, bitrate, options=[], verbose=False):
    """Run opus_demo in encoder mode, writing a bitstream file.

    Arguments mirror the opus_demo command line; extra encoder flags go in
    *options*. Returns 0 on success, 1 when launching the encoder raised.
    """
    call_args = [
        opus_demo_path,
        "-e",
        application,
        str(fs),
        str(num_channels),
        str(bitrate),
        "-bandwidth",
        "WB"
    ]
    call_args += options
    call_args += [
        input_pcm_path,
        bitstream_path
    ]
    try:
        if verbose:
            print(f"running {call_args}...")
        subprocess.run(call_args)
    except Exception:
        # bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit
        return 1
    return 0
def run_opus_decoder(opus_demo_path, bitstream_path, output_pcm_path, fs, num_channels, options=[], verbose=False):
    """Run opus_demo in decoder mode, writing raw PCM output.

    Arguments mirror the opus_demo command line; extra decoder flags go in
    *options*. Returns 0 on success, 1 when launching the decoder raised.
    """
    call_args = [
        opus_demo_path,
        "-d",
        str(fs),
        str(num_channels)
    ]
    call_args += options
    call_args += [
        bitstream_path,
        output_pcm_path
    ]
    try:
        if verbose:
            print(f"running {call_args}...")
        subprocess.run(call_args)
    except Exception:
        # bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit
        return 1
    return 0
def compute_moc_score(reference_pcm, test_pcm, delay=91):
    """MOC distance between two 16-bit mono PCM files.

    *delay* drops the leading samples of the test file so the two signals
    line up; 91 samples presumably matches the decoder delay at 16 kHz —
    confirm if the processing pipeline changes.
    """
    x_ref = np.fromfile(reference_pcm, dtype=np.int16).astype(np.float32) / (2 ** 15)
    x_cut = np.fromfile(test_pcm, dtype=np.int16).astype(np.float32) / (2 ** 15)
    moc_score = moc(x_ref, x_cut[delay:])
    return moc_score
def sox(*call_args):
    """Run sox with the given arguments; return 0 on success, 1 on failure.

    Failure covers both a missing sox binary and a nonzero sox exit status:
    the original version ignored the exit status (subprocess.run does not
    raise on it) and reported success even when the conversion failed.
    """
    try:
        call_args = ["sox"] + list(call_args)
        # check=True turns a nonzero exit status into CalledProcessError
        subprocess.run(call_args, check=True)
        return 0
    except Exception:
        return 1
def process_clip_factory(ref_opus_demo, test_opus_demo, enc_options, test_options):
    """Bind the binaries/options and return a process_clip closure so the
    per-clip loop only needs to pass paths and the bitrate."""
    def process_clip(clip_path, processdir, bitrate):
        """Encode the clip once, decode with both binaries, and return the
        MOC distances (d_ref, d_test) against the uncompressed input."""
        # derive working-file paths from the clip name
        clipname = os.path.splitext(os.path.split(clip_path)[1])[0]
        pcm_path = os.path.join(processdir, clipname + ".raw")
        bitstream_path = os.path.join(processdir, clipname + ".bin")
        ref_path = os.path.join(processdir, clipname + "_ref.raw")
        test_path = os.path.join(processdir, clipname + "_test.raw")

        # convert the clip to raw pcm
        sox(clip_path, pcm_path)

        # encode once; both decoders share the same bitstream
        run_opus_encoder(ref_opus_demo, pcm_path, bitstream_path, "voip", 16000, 1, bitrate, enc_options)

        # decode with the reference and the decoder under test
        run_opus_decoder(ref_opus_demo, bitstream_path, ref_path, 16000, 1)
        run_opus_decoder(test_opus_demo, bitstream_path, test_path, 16000, 1, options=test_options)

        d_ref = compute_moc_score(pcm_path, ref_path)
        d_test = compute_moc_score(pcm_path, test_path)

        return d_ref, d_test

    return process_clip
def main(inputdir, outputdir, bitrate, reference_opus_demo, test_opus_demo, enc_option_string, test_option_string, verbose):
    """Run the decoder regression test over all languages in inputdir.

    A language passes when neither the worst clip nor the language mean
    degrades too much relative to the reference decoder (thresholds -0.1
    and -0.025 on the relative MOC difference). Per-clip results are saved
    to outputdir for later inspection.
    """
    # load clips list: {language: [clip paths relative to inputdir]}
    with open(os.path.join(inputdir, 'clips.yml'), "r") as f:
        clips = yaml.safe_load(f)

    # parse option strings into argv-style lists
    enc_options = enc_option_string.split()
    test_options = test_option_string.split()

    process_clip = process_clip_factory(reference_opus_demo, test_opus_demo, enc_options, test_options)

    os.makedirs(outputdir, exist_ok=True)
    processdir = os.path.join(outputdir, 'process')
    os.makedirs(processdir, exist_ok=True)

    num_passed = 0
    results = dict()
    # global worst-case trackers for the final report
    min_rel_diff = 1000
    min_mean = 1000
    worst_clip = None
    worst_lang = None
    for lang, lang_clips in clips.items():
        if verbose > 0: print(f"processing language {lang}...")
        # column 0: reference decoder MOC, column 1: test decoder MOC
        results[lang] = np.zeros((len(lang_clips), 2))
        for i, clip in enumerate(lang_clips):
            clip_path = os.path.join(inputdir, clip)
            d_ref, d_test = process_clip(clip_path, processdir, bitrate)
            results[lang][i, 0] = d_ref
            results[lang][i, 1] = d_test

        # relative improvement of the test decoder (positive = better);
        # alpha compresses the distance scale before comparison
        alpha = 0.5
        rel_diff = ((results[lang][:, 0] ** alpha - results[lang][:, 1] ** alpha) /(results[lang][:, 0] ** alpha))

        # track the globally worst clip and worst per-language mean
        min_idx = np.argmin(rel_diff).item()
        if rel_diff[min_idx] < min_rel_diff:
            min_rel_diff = rel_diff[min_idx]
            worst_clip = lang_clips[min_idx]
        if np.mean(rel_diff) < min_mean:
            min_mean = np.mean(rel_diff).item()
            worst_lang = lang

        # fail if any clip or the language mean regresses beyond threshold
        if np.min(rel_diff) < -0.1 or np.mean(rel_diff) < -0.025:
            if verbose > 0: print(f"FAIL ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            if verbose > 1:
                for i, c in enumerate(lang_clips):
                    print(f"    {c:50s} {results[lang][i]} {rel_diff[i]}")
        else:
            if verbose > 0: print(f"PASS ({np.mean(results[lang], axis=0)} {np.mean(rel_diff)} {np.min(rel_diff)})")
            num_passed += 1

    print(f"{num_passed}/{len(clips)} tests passed!")
    print(f"worst case occured at clip {worst_clip} with relative difference of {min_rel_diff}")
    print(f"worst mean relative difference was {min_mean} for test {worst_lang}")

    np.save(os.path.join(outputdir, f'results_' + "_".join(test_options) + f"_{bitrate}.npy"), results, allow_pickle=True)
if __name__ == "__main__":
    # entry point: forward the parsed CLI arguments to main()
    args = parser.parse_args()
    main(args.inputdir,
         args.outputdir,
         args.bitrate,
         args.reference_opus_demo,
         args.test_opus_demo,
         args.encoder_options,
         args.test_opus_demo_options,
         args.verbose)

View File

@@ -0,0 +1,205 @@
""" module for inspecting models during inference """
import os
import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import numpy as np
# Global registry of open capture streams, keyed by user-chosen name:
# {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
# output folder for .bin/.yml pairs; overridden by init()
_folder = 'endoscopy'
def get_gru_gates(gru, input, state):
    """Compute the gate activations of a single-layer GRU for one step.

    Re-evaluates the torch GRU cell equations on (input, state) so the
    reset/update/new gate values can be inspected. Returns a dict with
    keys 'reset_gate', 'update_gate' and 'new_gate'.
    """
    h = gru.hidden_size
    x_proj = torch.matmul(gru.weight_ih_l0, input.squeeze())
    h_proj = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def _gate_parts(gate_index):
        # input-side and hidden-side pre-activations for one gate slice
        lo, hi = gate_index * h, (gate_index + 1) * h
        x_part = x_proj[lo:hi] + gru.bias_ih_l0[lo:hi]
        h_part = h_proj[lo:hi] + gru.bias_hh_l0[lo:hi]
        return x_part, h_part

    xr, hr = _gate_parts(0)
    reset_gate = torch.sigmoid(xr + hr)

    xz, hz = _gate_parts(1)
    update_gate = torch.sigmoid(xz + hz)

    # the reset gate modulates only the recurrent part of the new gate
    xn, hn = _gate_parts(2)
    new_gate = torch.tanh(xn + reset_gate * hn)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
def init(folder='endoscopy'):
    """Select (and if necessary create) the output folder for endoscopy data."""
    global _folder
    _folder = folder
    if os.path.exists(folder):
        # reusing a folder may mix data from different runs
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
    else:
        os.makedirs(folder)
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On first use of *key*, a binary stream <folder>/<key>.bin is opened and
    a sidecar <key>.yml recording sampling rate, shape and dtype is
    written; later calls must pass data with the same fs/shape/dtype.
    """
    global _state

    # convert to numpy if torch.Tensor is given
    if isinstance(data, torch.Tensor):
        data = data.detach().numpy()

    if not key in _state:
        # first write for this key: open the stream and record the signature
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }
        # sidecar lets read_data() decode the raw stream later; the dtype
        # is stripped of a 'torch.'/'numpy.' style prefix
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        # enforce a consistent signature for every write under the same key
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())
def close(folder='endoscopy'):
    """ clean up: close all open data files and reset the registry """
    global _state
    for entry in _state.values():
        entry['fid'].close()
    # bug fix: drop the stale entries so a later init()/write_data() cycle
    # starts fresh instead of writing to already-closed file handles
    _state = dict()
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Returns {key: {'fs': ..., 'dim': ..., 'dtype': ..., 'data': ndarray}}
    where 'data' has shape (num_writes,) + dim.
    """
    # every capture stream leaves a <key>.yml sidecar next to <key>.bin
    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()
    for key in keys:
        # the sidecar describes how to decode the raw stream
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # leading axis enumerates the individual write_data() calls
        value['data'] = data.reshape((-1,) + value['dim'])
        return_dict[key] = value

    return return_dict
def get_best_reshape(shape, target_ratio=1):
    """Pick a (rows, cols) factorization of the element count whose
    rows/cols ratio is close to target_ratio; single elements stay 1-D."""
    pixel_count = shape[0]
    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim

    if pixel_count == 1:
        return (1,)

    # start from the column count implied by the target ratio and walk
    # down to the nearest exact divisor of the element count
    num_columns = int((pixel_count / target_ratio) ** .5)
    while pixel_count % num_columns:
        num_columns -= 1

    return (pixel_count // num_columns, num_columns)
def get_type_and_shape(shape):
    """Classify captured data for display: single values as ('plot', (1,)),
    everything else as ('image', <2d shape>)."""
    # zero-dimensional data counts as a single scalar
    if not shape:
        shape = (1,)

    pixel_count = shape[0]
    if len(shape) > 1:
        pixel_count = 1
        for dim in shape:
            pixel_count *= dim

    if pixel_count == 1:
        return 'plot', (1, )

    # an existing 2d layout is kept unless it is degenerate
    if len(shape) == 2 and (shape[0] != pixel_count or shape[1] != pixel_count):
        return 'image', shape

    return 'image', get_best_reshape(shape)
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """Render all captured streams (as returned by read_data) as one video.

    Scalar streams are drawn as sliding waveform plots, multi-dimensional
    streams as images; one video frame is produced per sample index in
    [start_index, stop_index). A negative stop_index counts from the end.
    """
    # determine plot setup: roughly 4:3 grid of subplots, one per stream
    num_keys = len(data.keys())

    num_rows = int((num_keys * 3/4) ** .5)
    num_cols = (num_keys + num_rows - 1) // num_rows

    fig, axs = plt.subplots(num_rows, num_cols)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    keys = sorted(data.keys())

    # inspect data: decide display type and rate ratio per stream
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()

        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        # ratio of this stream's rate to the fastest stream's rate
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # clamp the frame range so the waveform window never leaves the data
    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            # NOTE(review): feature_index is computed but never used below —
            # the image branch presumably should index with it for
            # low-rate streams; confirm.
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,25 @@
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
def make_playback_animation(savepath, spec, duration_ms, vmin=20, vmax=90):
    """Render a spectrogram video with a cursor sweeping left to right.

    Args:
        savepath: output video path (container/codec chosen by matplotlib)
        spec: 2d spectrogram array (frequency x time), drawn bottom-up
        duration_ms: audio duration; one video frame per 20 ms
        vmin, vmax: color scale limits
    """
    fig, axs = plt.subplots()
    axs.set_axis_off()
    # 5 inches of width per second of audio
    fig.set_size_inches((duration_ms / 1000 * 5, 5))
    frames = []
    frame_duration=20
    num_frames = int(duration_ms / frame_duration + .99)  # ceiling division
    spec_height, spec_width = spec.shape

    for i in range(num_frames):
        # NOTE(review): the cursor position maps frames 1..num_frames-2 onto
        # the full spectrogram width; the first and last frame carry no
        # cursor at all — confirm the off-by-one mapping is intentional
        xpos = (i - 1) / (num_frames - 3) * (spec_width - 1)
        new_frame = axs.imshow(spec, cmap='inferno', origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        if i in {0, num_frames - 1}:
            frames.append([new_frame])
        else:
            line = axs.plot([xpos, xpos], [0, spec_height-1], color='white', alpha=0.8)[0]
            frames.append([new_frame, line])

    ani = matplotlib.animation.ArtistAnimation(fig, frames, blit=True, interval=frame_duration)

    ani.save(savepath, dpi=720)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long