add some code

This commit is contained in:
2025-09-05 13:25:11 +08:00
parent 9ff0a99e7a
commit 3cf1229a85
8911 changed files with 2535396 additions and 0 deletions

View File

@@ -0,0 +1,54 @@
# Framewise Auto-Regressive GAN (FARGAN)
Implementation of FARGAN, a low-complexity neural vocoder. Pre-trained models
are provided as C code in the dnn/ directory with the corresponding model in
dnn/models/ directory (name starts with fargan_). If you don't want to train
a new FARGAN model, you can skip straight to the Inference section.
## Data preparation
For data preparation you need to build Opus as detailed in the top-level README.
You will need to use the --enable-deep-plc configure option.
The build will produce an executable named "dump_data".
To prepare the training data, run:
```
./dump_data -train in_speech.pcm out_features.f32 out_speech.pcm
```
Where the in_speech.pcm speech file is a raw 16-bit PCM file sampled at 16 kHz.
The speech data used for training the model can be found at:
https://media.xiph.org/lpcnet/speech/tts_speech_negative_16k.sw
## Training
To perform pre-training, run the following command:
```
python ./train_fargan.py out_features.f32 out_speech.pcm output_dir --epochs 400 --batch-size 4096 --lr 0.002 --cuda-visible-devices 0
```
Once pre-training is complete, run adversarial training using:
```
python adv_train_fargan.py out_features.f32 out_speech.pcm output_dir --lr 0.000002 --reg-weight 5 --batch-size 160 --cuda-visible-devices 0 --initial-checkpoint output_dir/checkpoints/fargan_400.pth
```
The final model will be in output_dir/checkpoints/fargan_adv_50.pth.
The model can optionally be converted to C using:
```
python dump_fargan_weights.py output_dir/checkpoints/fargan_adv_50.pth fargan_c_dir
```
which will create a fargan_data.c and a fargan_data.h file in the fargan_c_dir directory.
Copy these files to the opus/dnn/ directory (replacing the existing ones) and recompile Opus.
## Inference
To run the inference, start by generating the features from the audio using:
```
./fargan_demo -features test_speech.pcm test_features.f32
```
Synthesis can be achieved either using the PyTorch code or the C code.
To synthesize from PyTorch, run:
```
python test_fargan.py output_dir/checkpoints/fargan_adv_50.pth test_features.f32 output_speech.pcm
```
To synthesize from the C code, run:
```
./fargan_demo -fargan-synthesis test_features.f32 output_speech.pcm
```

View File

@@ -0,0 +1,278 @@
import os
import argparse
import random
import numpy as np
import sys
import math as m
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
from stft_loss import *
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../osce/"))
import models as osce_models
def fmap_loss(scores_real, scores_gen):
    """Feature-matching loss between discriminator activations.

    scores_real / scores_gen are lists (one entry per sub-discriminator) of
    lists of feature maps; the final entry of each inner list is the score
    itself and is excluded. Returns the weighted sum of L1 distances, with
    real features detached so gradients flow only through the generator.
    """
    n_disc = len(scores_real)
    total = 0
    for real_feats, gen_feats in zip(scores_real, scores_gen):
        n_layers = len(gen_feats) - 1
        weight = 4 / n_disc / n_layers
        total += sum(weight * F.l1_loss(g, r.detach())
                     for g, r in zip(gen_feats[:-1], real_feats))
    return total
# ---------------------------------------------------------------------------
# Command-line interface for FARGAN adversarial training.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
# Fix: help-text typo ("comma separates" -> "comma separated").
parser.add_argument('--cuda-visible-devices', type=str, help="comma separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
training_group.add_argument('--lr', type=float, help='learning rate, default: 5e-4', default=5e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 50', default=50)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 60', default=60)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 0.0', default=0.0)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--reg-weight', type=float, help='regression loss weight, default: 1.0', default=1.0)
# Fix: help-text typo ("matchin" -> "matching").
training_group.add_argument('--fmap-weight', type=float, help='feature matching loss weight, default: 1.0', default=1.)

args = parser.parse_args()

# Restrict visible GPUs before torch initializes CUDA.
# Fix: idiomatic None check (was `args.cuda_visible_devices != None`).
if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.99]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size

# Record hyper-parameters in the checkpoint for reproducibility.
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

# discriminator (multi-resolution frequency-domain disc from the osce package)
disc_name = 'fdmresdisc'
disc = osce_models.model_dict[disc_name](
    architecture='free',
    design='f_down',
    fft_sizes_16k=[2**n for n in range(6, 12)],
    freq_roi=[0, 7400],
    max_channels=256,
    noise_gain=0.0
)

# Fix: idiomatic None test (was `type(args.initial_checkpoint) != type(None)`).
if args.initial_checkpoint is not None:
    # NOTE(review): this replaces the checkpoint dict built above with the
    # loaded one, so the hyper-parameters recorded earlier are discarded for
    # resumed runs — confirm this is intended.
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()

dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
optimizer_disc = torch.optim.AdamW([p for p in disc.parameters() if p.requires_grad], lr=lr, betas=adam_betas, eps=adam_eps)

# learning rate schedulers (inverse-linear decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
scheduler_disc = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_disc, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

spect_loss = MultiResolutionSTFTLoss(device).to(device)

# Freeze the whole generator initially; parts are unfrozen 400 batches into
# the first epoch (see the training loop).
for param in model.parameters():
    param.requires_grad = False

batch_count = 0
if __name__ == '__main__':
    model.to(device)
    disc.to(device)

    for epoch in range(1, epochs + 1):
        # Running EMA mean/std of discriminator scores on real (r) and fake (f)
        # batches; used only for the "winning chance" diagnostic below.
        m_r = 0
        m_f = 0
        s_r = 1
        s_f = 1
        running_cont_loss = 0
        running_disc_loss = 0
        running_gen_loss = 0
        running_fmap_loss = 0
        running_reg_loss = 0
        running_wc = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                # After 400 warm-up batches of the first epoch, unfreeze the
                # generator except for the conditioning net and the gain dense.
                if epoch == 1 and i == 400:
                    for param in model.parameters():
                        param.requires_grad = True
                    for param in model.cond_net.parameters():
                        param.requires_grad = False
                    for param in model.sig_net.cond_gain_dense.parameters():
                        param.requires_grad = False
                optimizer.zero_grad()
                features = features.to(device)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
                # Trim to the configured sequence length (160 samples/frame,
                # +4 frames of conditioning context). The disabled else-branch
                # halved the batch along dim 0 instead.
                if True:
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:,:sequence_length+4,:]
                    periods = periods[:,:sequence_length+4]
                else:
                    target=target[::2, :]
                    #lpc=lpc[::2,:]
                    features=features[::2,:]
                    periods=periods[::2,:]
                target = target.to(device)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

                # Teacher-force the first nb_pre frames, generate the rest.
                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                output, _ = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                output = torch.cat([pre, output], -1)

                # ---- discriminator update (LSGAN objective) ----
                scores_gen = disc(output.detach().unsqueeze(1))
                scores_real = disc(target.unsqueeze(1))
                disc_loss = 0
                for scale in scores_gen:
                    disc_loss += ((scale[-1]) ** 2).mean()
                    m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item()
                for scale in scores_real:
                    disc_loss += ((1 - scale[-1]) ** 2).mean()
                    m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item()
                    s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item()
                disc_loss = 0.5 * disc_loss / len(scores_gen)
                # Probability that a fake sample outscores a real one, assuming
                # Gaussian score distributions (diagnostic only).
                winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) )
                running_wc += winning_chance
                disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()

                # ---- generator update ----
                scores_gen = disc(output.unsqueeze(1))
                if False: # todo: check whether that makes a difference
                    with torch.no_grad():
                        scores_real = disc(target.unsqueeze(1))
                # Continuity loss on the 80 samples right after the teacher-forced
                # prefix; weighted to zero in reg_loss but still tracked.
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
                specc_loss = spect_loss(output, target.detach())
                reg_loss = (.00*cont_loss + specc_loss)
                loss_gen = 0
                for scale in scores_gen:
                    loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)
                feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
                reg_weight = args.reg_weight# + 15./(1 + (batch_count/7600.))
                gen_loss = reg_weight * reg_loss + feat_loss + loss_gen
                model.zero_grad()
                gen_loss.backward()
                optimizer.step()
                #model.clip_weights()
                scheduler.step()
                scheduler_disc.step()

                running_cont_loss += cont_loss.detach().cpu().item()
                running_gen_loss += loss_gen.detach().cpu().item()
                running_disc_loss += disc_loss.detach().cpu().item()
                running_fmap_loss += feat_loss.detach().cpu().item()
                running_reg_loss += reg_loss.detach().cpu().item()

                tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   reg_weight=f"{reg_weight:8.5f}",
                                   gen_loss=f"{running_gen_loss/(i+1):8.5f}",
                                   disc_loss=f"{running_disc_loss/(i+1):8.5f}",
                                   fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
                                   reg_loss=f"{running_reg_loss/(i+1):8.5f}",
                                   wc = f"{running_wc/(i+1):8.5f}",
                                   )
                batch_count = batch_count + 1

        # save a checkpoint once per epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        # NOTE(review): 'disc_sate_dict' looks like a typo for 'disc_state_dict';
        # any loader must use this exact key — confirm before renaming.
        checkpoint['disc_sate_dict'] = disc.state_dict()
        checkpoint['loss'] = {
            'cont': running_cont_loss / len(dataloader),
            'gen': running_gen_loss / len(dataloader),
            'disc': running_disc_loss / len(dataloader),
            'fmap': running_fmap_loss / len(dataloader),
            'reg': running_reg_loss / len(dataloader)
        }
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)

View File

@@ -0,0 +1,61 @@
import torch
import numpy as np
import fargan
class FARGANDataset(torch.utils.data.Dataset):
    """Memory-mapped dataset of aligned (features, pitch periods, PCM, LPC)
    training sequences for FARGAN.

    Each item covers 2x `sequence_length` frames of `frame_size` samples;
    consecutive items overlap by half (windows advance by one chunk but span
    two — see the as_strided shapes below).
    """
    def __init__(self,
                 feature_file,
                 signal_file,
                 frame_size=160,
                 sequence_length=15,
                 lookahead=1,
                 nb_used_features=20,
                 nb_features=36):
        self.frame_size = frame_size
        self.sequence_length = sequence_length
        self.lookahead = lookahead
        self.nb_features = nb_features
        self.nb_used_features = nb_used_features
        pcm_chunk_size = self.frame_size*self.sequence_length

        # Raw 16-bit PCM, memory-mapped so huge corpora never load fully.
        self.data = np.memmap(signal_file, dtype='int16', mode='r')
        #self.data = self.data[1::2]
        # The -4 keeps slack so the double-length windows below never run off
        # the end of the mapped file.
        self.nb_sequences = len(self.data)//(pcm_chunk_size)-4
        # Shift the signal to compensate feature lookahead (4-frame context).
        self.data = self.data[(4-self.lookahead)*self.frame_size:]
        self.data = self.data[:self.nb_sequences*pcm_chunk_size]
        #self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
        # Overlapping 2-chunk windows advancing by one chunk (zero-copy view).
        sizeof = self.data.strides[-1]
        self.data = np.lib.stride_tricks.as_strided(self.data, shape=(self.nb_sequences, pcm_chunk_size*2),
                                                    strides=(pcm_chunk_size*sizeof, sizeof))

        self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
        # Matching overlapping feature windows: 2*sequence_length frames plus
        # 4 frames of conditioning context.
        sizeof = self.features.strides[-1]
        self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
                                                        strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
        #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
        # Decode the pitch feature into integer periods clamped to [32, 255].
        self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')

        # Trailing feature columns hold the LPC coefficients.
        self.lpc = self.features[:, :, self.nb_used_features:]
        self.features = self.features[:, :, :self.nb_used_features]
        print("lpc_size:", self.lpc.shape)

    def __len__(self):
        # Number of (overlapping) training sequences.
        return self.nb_sequences

    def __getitem__(self, index):
        features = self.features[index, :, :].copy()
        # Align LPC with the signal by removing the lookahead frames.
        if self.lookahead != 0:
            lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
        else:
            lpc = self.lpc[index, 4:, :].copy()
        # Normalize int16 PCM to [-1, 1) float.
        data = self.data[index, :].copy().astype(np.float32) / 2**15
        periods = self.periods[index, :].copy()
        #lpc = lpc*(self.gamma**np.arange(1,17))
        #lpc=lpc[None,:,:]
        #lpc = fargan.interp_lpc(lpc, 4)
        #lpc=lpc[0,:,:]
        return features, periods, data, lpc

View File

@@ -0,0 +1,112 @@
import os
import sys
import argparse
import torch
from torch import nn
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import fargan
#from models import model_dict
unquantized = [ 'cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out' ]
unquantized2 = [
'cond_net.pembed',
'cond_net.fdense1',
'cond_net.fconv1',
'cond_net.fconv2',
'cont_net.0',
'sig_net.cond_gain_dense',
'sig_net.fwc0.conv',
'sig_net.fwc0.glu.gate',
'sig_net.dense1_glu.gate',
'sig_net.gru1_glu.gate',
'sig_net.gru2_glu.gate',
'sig_net.gru3_glu.gate',
'sig_net.skip_glu.gate',
'sig_net.skip_dense',
'sig_net.sig_dense_out',
'sig_net.gain_dense_out'
]
description=f"""
This is an unsafe dumping script for FARGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layer
and will fail to export any other weights.
Furthermore, the quanitze option relies on the following explicit list of layers to be excluded:
{unquantized}.
Modify this script manually if adjustments are needed.
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')
if __name__ == "__main__":
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    saved_gen= torch.load(args.weightfile, map_location='cpu')
    # Force the canonical inference configuration regardless of how the
    # checkpoint was trained.
    saved_gen['model_args'] = ()
    saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}

    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    model.load_state_dict(saved_gen['state_dict'], strict=False)

    def _remove_weight_norm(m):
        # Fold weight-norm reparameterization back into plain weights so the
        # C export sees a single weight tensor per layer.
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model=args.quantize
    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): only `wexchange.torch` is imported above; this relies on
    # that import also exposing `wexchange.c_export` — confirm it does.
    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name, add_typedef=True)

    for name, module in model.named_modules():
        # Quantized layers get no fixed scale (derived at export); layers kept
        # in float use a fixed 1/128 scale.
        if quantize_model:
            quantize=name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize=False
            scale=1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
            #wexchange.torch.dump_torch_embedding_weights(writer, module)

        else:
            print(f"Ignoring layer {name}...")

    writer.close()

View File

@@ -0,0 +1,346 @@
import os
import sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
#from convert_lsp import lpc_to_lsp, lsp_to_lpc
from rc import lpc2rc, rc2lpc
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../dnntools"))
from dnntools.quantization import soft_quant
# Sample rate used throughout this module (16 kHz).
Fs = 16000

# Cache of open debug-dump file handles, keyed by filename.
fid_dict = {}
def dump_signal(x, filename):
    """Debug helper: append tensor x to a raw float32 file.

    The bare `return` below deliberately disables dumping; delete it to
    re-enable the debug output.
    """
    return
    if filename in fid_dict:
        fid = fid_dict[filename]
    else:
        fid = open(filename, "w")
        fid_dict[filename] = fid
    x = x.detach().numpy().astype('float32')
    x.tofile(fid)
def sig_l1(y_true, y_pred):
    """Mean absolute error normalized by the mean absolute level of y_true."""
    err = torch.mean(torch.abs(y_true - y_pred))
    scale = torch.mean(torch.abs(y_true))
    return err / scale
def sig_loss(y_true, y_pred):
    """1 - cosine similarity along the last axis, averaged over the batch.

    Zero when prediction and target are colinear with the same sign; 2 when
    exactly opposite. Scale-invariant in both arguments.
    """
    ref_unit = y_true / (1e-15 + torch.norm(y_true, dim=-1, p=2, keepdim=True))
    est_unit = y_pred / (1e-15 + torch.norm(y_pred, dim=-1, p=2, keepdim=True))
    cos_sim = torch.sum(est_unit * ref_unit, dim=-1)
    return torch.mean(1. - cos_sim)
def interp_lpc(lpc, factor):
    """Upsample LPC coefficients in time by `factor` (frame -> subframe rate).

    Interpolation is done in the atanh-of-reflection-coefficient domain
    (log-area ratios), which stays stable under linear interpolation, then
    mapped back to LPC. Edge positions are padded by replication.
    lpc: (B, nb_frames, order) — returns (B, nb_frames*factor, order).
    """
    #print(lpc.shape)
    #f = (np.arange(factor)+.5*((factor+1)%2))/factor
    # NOTE(review): variable named `lsp` but these are transformed reflection
    # coefficients, not line spectral pairs — naming only, math is consistent.
    lsp = torch.atanh(lpc2rc(lpc))
    #print("lsp0:")
    #print(lsp)
    shape = lsp.shape
    #print("shape is", shape)
    shape = (shape[0], shape[1]*factor, shape[2])
    interp_lsp = torch.zeros(shape, device=lpc.device)
    # Linear interpolation between consecutive frames at `factor` phases.
    for k in range(factor):
        f = (k+.5*((factor+1)%2))/factor
        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
    # Replicate first/last interpolated values into the uncovered edges.
    for k in range(factor//2):
        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
    for k in range((factor+1)//2):
        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
    #print("lsp:")
    #print(interp_lsp)
    return rc2lpc(torch.tanh(interp_lsp))
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
    """Apply the LPC analysis (FIR) filter A(z) to x, frame by frame.

    x: (B, nb_frames*nb_subframes*subframe_size) signal; lpc: (B, nb_frames,
    order) coefficients. Returns the excitation/residual, same length as x.
    `gamma` is currently unused — the bandwidth-expansion code is commented out.
    """
    device = x.device
    batch_size = lpc.size(0)
    nb_frames = lpc.shape[1]
    # Rolling window: one subframe plus 16 samples of history
    # (16 presumably the LPC order — see the commented arange(1,17) below).
    sig = torch.zeros(batch_size, subframe_size+16, device=device)
    x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
    out = torch.zeros((batch_size, 0), device=device)
    #if gamma is not None:
    #    bw = gamma**(torch.arange(1, 17, device=device))
    #    lpc = lpc*bw[None,None,:]
    # Prepend a0 = 1 and zero-pad, then build a banded Toeplitz matrix so the
    # FIR filtering becomes one batched matmul per subframe.
    ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
    zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
    a = torch.cat([ones, lpc], -1)
    a_big = torch.cat([a, zeros], -1)
    fir_mat_big = filters.toeplitz_from_filter(a_big)
    #print(a_big[:,0,:])
    for n in range(nb_frames):
        for k in range(nb_subframes):
            # Slide the window one subframe forward and filter it.
            sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
            exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
            out = torch.cat([out, exc[:,-subframe_size:,0]], 1)
    return out
# weight initialization and clipping
def init_weights(module):
    """Orthogonally initialize the recurrent (hidden-to-hidden) weights of
    GRU modules; all other modules are left untouched."""
    if not isinstance(module, nn.GRU):
        return
    for name, param in module.named_parameters():
        if name.startswith('weight_hh_'):
            nn.init.orthogonal_(param)
def gen_phase_embedding(periods, frame_size):
    """Build per-sample (cos, sin) pitch-phase embeddings.

    periods: (B, nb_frames) pitch period per frame, in samples. The phase
    advances by w0 = 2*pi/period within each frame and accumulates across
    frames, with a random initial phase per batch element.
    Returns two tensors of shape (B, nb_frames*frame_size).
    """
    device = periods.device
    batch = periods.size(0)
    nb_frames = periods.size(1)
    w0 = 2 * torch.pi / periods
    # Random start phase, then each frame continues from the previous frame's
    # accumulated phase (shift w0 by one frame before the cumulative sum).
    rand_start = 2 * torch.pi * torch.rand((batch, 1), device=device) / frame_size
    w0_prev = torch.cat([rand_start, w0[:, :-1]], 1)
    cum_phase = frame_size * torch.cumsum(w0_prev, 1)
    # Within-frame phase ramp: sample index times the frame's w0.
    sample_idx = torch.arange(frame_size, device=device)
    fine_phase = w0[:, :, None] * sample_idx[None, None, :].expand(batch, nb_frames, frame_size)
    phase = cum_phase[:, :, None] + fine_phase
    phase = torch.reshape(phase, (batch, -1))
    return torch.cos(phase), torch.sin(phase)
class GLU(nn.Module):
    """Gated Linear Unit: y = x * sigmoid(W x), with a weight-normed,
    bias-free gate and deterministic orthogonal initialization."""

    def __init__(self, feat_size, softquant=False):
        super(GLU, self).__init__()
        torch.manual_seed(5)  # deterministic init across runs
        gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
        self.gate = soft_quant(gate) if softquant else gate
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every linear-like submodule.
        linearish = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for mod in self.modules():
            if isinstance(mod, linearish):
                nn.init.orthogonal_(mod.weight.data)

    def forward(self, x):
        return x * torch.sigmoid(self.gate(x))
class FWConv(nn.Module):
    """Causal frame-wise "convolution": a bias-free Linear applied to the
    concatenation of the previous `kernel_size-1` inputs (carried in `state`)
    and the current input, followed by tanh and a GLU gate."""

    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
        super(FWConv, self).__init__()
        torch.manual_seed(5)  # deterministic init across runs
        self.in_size = in_size
        self.kernel_size = kernel_size
        # NB: creation/registration order (conv, then glu) is kept as-is so
        # RNG draws and module traversal order stay reproducible.
        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
        self.glu = GLU(out_size, softquant=softquant)
        if softquant:
            self.conv = soft_quant(self.conv)
        self.init_weights()

    def init_weights(self):
        # Orthogonal init for every linear-like submodule.
        linearish = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for mod in self.modules():
            if isinstance(mod, linearish):
                nn.init.orthogonal_(mod.weight.data)

    def forward(self, x, state):
        window = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(window)))
        # Drop the oldest frame from the window to form the next state.
        return out, window[:, self.in_size:]
def n(x):
    """Add uniform dither of +/- 0.5/127 to x and clamp to [-1, 1]
    (simulates 8-bit quantization noise during training)."""
    dither = (torch.rand_like(x) - .5) * (1. / 127.)
    return torch.clamp(x + dither, min=-1., max=1.)
class FARGANCond(nn.Module):
    """Conditioning network: merges acoustic features with a learned pitch
    embedding and produces one 80-dim conditioning vector per subframe
    (4 subframes per frame)."""

    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
        super(FARGANCond, self).__init__()
        self.feature_dim = feature_dim
        self.cond_size = cond_size
        # NB: layer creation order is kept as-is for reproducible init.
        self.pembed = nn.Embedding(224, pembed_dims)
        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
        self.fdense2 = nn.Linear(128, 80*4, bias=False)
        if softquant:
            self.fconv1 = soft_quant(self.fconv1)
            self.fdense2 = soft_quant(self.fdense2)
        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"cond model: {nb_params} weights")

    def forward(self, features, period):
        # Drop the first two context frames (the valid conv trims two more).
        feats = features[:, 2:, :]
        per = period[:, 2:]
        # Periods are clamped to [32, 255] upstream; shift to embedding range.
        emb = self.pembed(per - 32)
        x = torch.cat((feats, emb), -1)
        x = torch.tanh(self.fdense1(x))
        # Conv1d wants (B, C, T); permute in and back out.
        x = torch.tanh(self.fconv1(x.permute(0, 2, 1))).permute(0, 2, 1)
        return torch.tanh(self.fdense2(x))
class FARGANSub(nn.Module):
    """Auto-regressive subframe network: synthesizes one 40-sample subframe
    per call from the conditioning vector, a pitch-delayed slice of the
    excitation memory, and recurrent state."""
    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
        super(FARGANSub, self).__init__()
        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.cond_size = cond_size
        # Scalar log-gain predicted from the 80-dim subframe conditioning.
        self.cond_gain_dense = nn.Linear(80, 1)
        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
        # Input = cond (80) + pitch prediction (subframe+4) + prev subframe.
        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
        self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
        self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
        self.gru1_glu = GLU(160, softquant=softquant)
        self.gru2_glu = GLU(128, softquant=softquant)
        self.gru3_glu = GLU(128, softquant=softquant)
        self.skip_glu = GLU(128, softquant=softquant)
        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
        self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
        self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
        # Four sigmoid pitch gains, one per layer that consumes fpitch.
        self.gain_dense_out = nn.Linear(192, 4)

        if softquant:
            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
            self.skip_dense = soft_quant(self.skip_dense)
            self.sig_dense_out = soft_quant(self.sig_dense_out)

        self.apply(init_weights)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"subframe model: {nb_params} weights")

    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
        """Generate one subframe.

        cond: (B, 80) subframe conditioning; prev_pred: (B, 256) pitch-
        prediction history (updated and returned); exc_mem: (B, 256) past
        output samples; period: (B,) pitch period in samples; states: tuple
        (gru1, gru2, gru3, fwc0) of recurrent states.
        Returns (sig_out, exc_mem, prev_pred, new_states).
        """
        device = exc_mem.device
        #print(cond.shape, prev.shape)

        cond = n(cond)  # dither the conditioning
        dump_signal(gain, 'gain0.f32')
        # NOTE(review): the passed-in `gain` is immediately overwritten by the
        # predicted gain — the argument is only dumped, never used; confirm.
        gain = torch.exp(self.cond_gain_dense(cond))
        dump_signal(gain, 'gain1.f32')
        # Gather subframe_size+4 samples delayed by `period` from the
        # 256-sample memory; indices past the end wrap back by one period.
        idx = 256-period[:,None]
        rng = torch.arange(self.subframe_size+4, device=device)
        idx = idx + rng[None,:] - 2
        mask = idx >= 256
        idx = idx - mask*period[:,None]
        pred = torch.gather(exc_mem, 1, idx)
        pred = n(pred/(1e-5+gain))  # gain-normalized, dithered pitch prediction

        prev = exc_mem[:,-self.subframe_size:]
        dump_signal(prev, 'prev_in.f32')
        prev = n(prev/(1e-5+gain))
        dump_signal(prev, 'pitch_exc.f32')
        dump_signal(exc_mem, 'exc_mem.f32')

        tmp = torch.cat((cond, pred, prev), 1)
        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
        # Centered subframe_size-sample window of the pitch prediction.
        fpitch = pred[:,2:-2]

        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
        fwc0_out = n(fwc0_out)
        # Per-layer gains applied to the pitch prediction fed to each stage.
        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
        gru1_out = self.gru1_glu(n(gru1_state))
        gru1_out = n(gru1_out)
        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
        gru2_out = self.gru2_glu(n(gru2_state))
        gru2_out = n(gru2_out)
        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
        gru3_out = self.gru3_glu(n(gru3_state))
        gru3_out = n(gru3_out)
        # Skip connection over all intermediate activations.
        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
        skip_out = self.skip_glu(n(skip_out))
        sig_out = torch.tanh(self.sig_dense_out(skip_out))
        dump_signal(sig_out, 'exc_out.f32')
        #taps = self.ptaps_dense(gru3_out)
        #taps = .2*taps + torch.exp(taps)
        #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
        #dump_signal(taps, 'taps.f32')

        dump_signal(pitch_gain, 'pgain.f32')
        #sig_out = (sig_out + pitch_gain*fpitch) * gain
        # Restore the predicted gain on the normalized output.
        sig_out = sig_out * gain
        exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
        dump_signal(sig_out, 'sig_out.f32')
        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
class FARGAN(nn.Module):
    """Top-level FARGAN vocoder: conditioning net + auto-regressive subframe
    net, run nb_subframes (4) times per 160-sample frame."""
    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
        # NOTE(review): passthrough_size/has_gain/gamma are accepted (training
        # code passes gamma via model_kwargs) but unused in this class body —
        # presumably kept for checkpoint/config compatibility; confirm.
        super(FARGAN, self).__init__()

        self.subframe_size = subframe_size
        self.nb_subframes = nb_subframes
        self.frame_size = self.subframe_size*self.nb_subframes
        self.feature_dim = feature_dim
        self.cond_size = cond_size

        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)

    def forward(self, features, period, nb_frames, pre=None, states=None):
        """Synthesize nb_frames frames, optionally teacher-forced by `pre`.

        NOTE(review): the `states` argument is ignored — it is always
        re-initialized to zeros below; confirm whether streaming state was
        intended.
        """
        device = features.device
        batch_size = features.size(0)

        prev = torch.zeros(batch_size, 256, device=device)
        exc_mem = torch.zeros(batch_size, 256, device=device)
        nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

        states = (
            torch.zeros(batch_size, 160, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, 128, device=device),
            torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
        )

        sig = torch.zeros((batch_size, 0), device=device)
        cond = self.cond_net(features, period)
        if pre is not None:
            # Seed the excitation memory with the first teacher-forced frame.
            exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
        start = 1 if nb_pre_frames>0 else 0
        # NOTE(review): the loop variable `n` shadows the module-level dither
        # helper n(); harmless here since n() is not called in this scope.
        for n in range(start, nb_frames+nb_pre_frames):
            for k in range(self.nb_subframes):
                pos = n*self.frame_size + k*self.subframe_size
                #print("now: ", preal.shape, prev.shape, sig_in.shape)
                pitch = period[:, 3+n]
                # Frame gain derived from the first (energy) feature.
                gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                #gain = gain[:,:,None]
                out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
                if n < nb_pre_frames:
                    # Teacher forcing: overwrite the output with ground truth.
                    out = pre[:, pos:pos+self.subframe_size]
                    exc_mem[:,-self.subframe_size:] = out
                else:
                    sig = torch.cat([sig, out], 1)

        states = [s.detach() for s in states]
        return sig, states

View File

@@ -0,0 +1,46 @@
import torch
from torch import nn
import torch.nn.functional as F
import math
def toeplitz_from_filter(a):
    """Build the lower-triangular banded Toeplitz (convolution) matrix of the
    FIR filter a.

    a: (..., L) filter taps. Returns (..., L, L) where row i, column j holds
    a[i-j] for 0 <= i-j < L and 0 elsewhere, so multiplying a signal vector
    by this matrix applies the filter.
    """
    device = a.device
    L = a.size(-1)
    batch_dims = a.shape[:-1]
    # Prepend a zero so off-band entries can be gathered from index 0.
    taps = torch.cat([torch.zeros_like(a[..., :1]), a], -1)
    rows = torch.arange(L, dtype=torch.int64, device=device)
    # Offset i-j+1 into the zero-padded taps, floored at the zero slot.
    offs = torch.clamp(rows[:, None] - rows[None, :] + 1, min=0)
    taps = torch.broadcast_to(taps[..., None, :], (*batch_dims, L, L + 1))
    offs = torch.broadcast_to(offs, (*batch_dims, L, L))
    return torch.gather(taps, -1, offs)
def filter_iir_response(a, N):
    """Truncated impulse response (length N) of the all-pole filter 1/A(z).

    a: (B, F, L) denominator taps with a[..., 0] == 1. Implements the
    recursion r[i] = -sum_{j=1..L-1} a[j] * r[i-j] with r[0] = 1.
    """
    L = a.size(-1)
    rev = a.flip(dims=(2,))  # reversed taps, highest order first
    R = torch.zeros((*a.shape[:-1], N), device=a.device)
    R[:, :, 0] = 1.0
    # Warm-up: fewer than L-1 past samples are available.
    for i in range(1, L):
        R[:, :, i] = -(rev[:, :, L - i - 1:-1] * R[:, :, :i]).sum(dim=-1)
    # Steady state: full history window of L-1 samples.
    for i in range(L, N):
        R[:, :, i] = -(rev[:, :, :-1] * R[:, :, i - L + 1:i]).sum(dim=-1)
    return R
if __name__ == '__main__':
    # Sanity check: the Toeplitz matrix of the impulse response of 1/A(z)
    # should invert the Toeplitz matrix of A(z) (RA close to identity).
    #a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
    a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
    A = toeplitz_from_filter(a)
    #print(A)
    R = filter_iir_response(a, 5)
    RA = toeplitz_from_filter(R)
    print(RA)

View File

@@ -0,0 +1,29 @@
import torch
def rc2lpc(rc):
    """Convert reflection coefficients to LPC via the Levinson recursion.

    rc: (..., order) reflection coefficients; returns (..., order) LPC
    predictor coefficients (without the leading 1).
    """
    order = rc.shape[-1]
    lpc = rc[..., 0:1]
    for i in range(1, order):
        k = rc[..., i:i+1]
        lpc = torch.cat([lpc + k * torch.flip(lpc, dims=(-1,)), k], -1)
    return lpc
def lpc2rc(lpc):
    """Convert LPC coefficients to reflection coefficients (inverse Levinson
    step-down recursion).

    lpc: (..., order) predictor coefficients (without the leading 1);
    returns (..., order) reflection coefficients.
    """
    order = lpc.shape[-1]
    rc = lpc[..., -1:]
    for _ in range(order - 1):
        k = lpc[..., -1:]
        rest = lpc[..., :-1]
        lpc = (rest - k * torch.flip(rest, dims=(-1,))) / (1 - k * k)
        rc = torch.cat([lpc[..., -1:], rc], -1)
    return rc
if __name__ == "__main__":
    # Round-trip sanity check: rc -> lpc -> rc should reproduce the input.
    rc = torch.tensor([[.5, -.5, .6, -.6]])
    print(rc)
    lpc = rc2lpc(rc)
    print(lpc)
    rc2 = lpc2rc(lpc)
    print(rc2)

View File

@@ -0,0 +1,186 @@
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
import numpy as np
import torchaudio
def stft(x, fft_size, hop_size, win_length, window):
    """Magnitude spectrogram of x, clamped away from zero.

    Args:
        x (Tensor): input signal (B, T).
        fft_size (int): FFT size.
        hop_size (int): hop between frames.
        win_length (int): analysis window length.
        window (Tensor): window tensor of length win_length
            (e.g. torch.hann_window).

    Returns:
        Tensor: magnitudes of shape (B, fft_size // 2 + 1, #frames), clamped
        to >= 1e-7 so downstream log/div operations cannot produce nan/inf.
    """
    spec = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    return torch.abs(spec).clamp(min=1e-7)
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss on square-root-compressed magnitudes.

    Computed as ||sqrt(y) - sqrt(x)||_1 / ||sqrt(y)||_1, i.e. the relative
    L1 error of the predicted magnitude spectrogram.
    """

    def __init__(self):
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Compute the loss.

        Args:
            x_mag (Tensor): predicted magnitudes (B, #frames, #freq_bins).
            y_mag (Tensor): ground-truth magnitudes (B, #frames, #freq_bins).

        Returns:
            Tensor: scalar loss value.
        """
        gen = x_mag.sqrt()
        ref = y_mag.sqrt()
        return torch.norm(ref - gen, p=1) / torch.norm(ref, p=1)
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Magnitude-spectrogram loss (currently a plain L1, despite the name)."""

    def __init__(self):
        """Initialize the magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x, y):
        """Return the mean absolute error between two magnitude spectrograms.

        Args:
            x (Tensor): Magnitude spectrogram of predicted signal.
            y (Tensor): Magnitude spectrogram of groundtruth signal.

        Returns:
            Tensor: Scalar L1 loss.
        """
        return F.l1_loss(y, x)
class STFTLoss(torch.nn.Module):
    """STFT loss at a single FFT/hop/window resolution."""

    def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Create the window tensor and the two sub-losses.

        Args:
            device: Device the window tensor is allocated on.
            fft_size (int): FFT size.
            shift_size (int): Hop size.
            win_length (int): Window length.
            window (str): Name of a torch window factory, e.g. "hann_window".
        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length).to(device)
        self.spectral_convergenge_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Return (spectral convergence loss, magnitude loss) for one resolution.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Magnitude loss value.
        """
        pred_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        ref_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc = self.spectral_convergenge_loss(pred_mag, ref_mag)
        mag = self.log_stft_magnitude_loss(pred_mag, ref_mag)
        return sc, mag
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi-resolution STFT loss.

    Averages the spectral-convergence term of STFTLoss over several
    FFT/hop/window resolutions. Only the spectral-convergence loss is
    returned; the magnitude term computed by STFTLoss is discarded
    (its accumulation was disabled during experimentation).
    """

    def __init__(self,
                 device,
                 fft_sizes=(2560, 1280, 640, 320, 160, 80),
                 hop_sizes=(640, 320, 160, 80, 40, 20),
                 win_lengths=(2560, 1280, 640, 320, 160, 80),
                 window="hann_window"):
        """Initialize multi-resolution STFT loss.

        Args:
            device: Device on which the STFT windows are allocated.
            fft_sizes (sequence): FFT size for each resolution.
            hop_sizes (sequence): Hop size for each resolution.
            win_lengths (sequence): Window length for each resolution.
            window (str): Name of a torch window factory, e.g. "hann_window".
        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss averaged over all resolutions.
        """
        sc_loss = 0.0
        for f in self.stft_losses:
            sc_l, _ = f(x, y)
            sc_loss += sc_l
        return sc_loss / len(self.stft_losses)

View File

@@ -0,0 +1,128 @@
# Inference script: generates a 16-bit PCM waveform from a .f32 feature file
# using a trained FARGAN model checkpoint.
import os
import argparse
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
# Feature layout: 36 floats per frame; the first 20 condition the model,
# the remaining 16 are LPC coefficients.
nb_features = 36
nb_used_features = 20
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
args = parser.parse_args()
# Restrict visible GPUs before any CUDA context is created.
if args.cuda_visible_devices != None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
features_file = args.features
signal_file = args.output
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model hyper-parameters travel inside the checkpoint, so the architecture is
# reconstructed exactly as trained.
checkpoint = torch.load(args.model, map_location='cpu')
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
# memmap avoids loading the whole feature file; reshape to (1, frames, 36).
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
# LPC coefficients per frame. NOTE(review): the 4-1:-1 slice presumably aligns
# the LPCs with the model's conditioning look-ahead -- confirm against the
# training-data layout.
lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features]
#periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
# Decode the pitch period (in samples) from the log-scale pitch feature,
# clamped to the [32, 255] range the model expects.
periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')
nb_frames = features.shape[1]
#nb_frames = 1000
# Bandwidth-expansion factor used for perceptual weighting at training time.
gamma = checkpoint['model_kwargs']['gamma']
def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=None):
    """Run one frame through the all-pole (IIR) LPC synthesis filter.

    Computes s[n] = frame[n] - sum_k filt[k-1] * w[k] * s[n-k] sample by
    sample, where the filter state is seeded from `buffer`.

    Args:
        frame: 1-D array of input (excitation) samples for this frame.
        filt: 16 LPC coefficients a_1..a_16 (most recent tap first).
        buffer: 16-sample filter state, oldest sample first. Not modified;
            callers are expected to refresh their state from the returned
            samples themselves.
        weighting_vector: optional per-tap weighting applied to the state;
            defaults to all-ones (no weighting).

    Returns:
        Array of synthesized samples, same shape/dtype as `frame`.
    """
    # Avoid a mutable default argument: build the neutral weighting per call.
    if weighting_vector is None:
        weighting_vector = np.ones(16)
    out = np.zeros_like(frame)
    # Work on a copy so the caller's state array is never clobbered.
    state = buffer.copy()
    # Reverse taps once so the dot product pairs the newest state sample
    # (kept at the end of `state`) with a_1.
    taps = np.flip(filt)
    for i in range(frame.shape[0]):
        s = frame[i] - np.dot(state * weighting_vector, taps)
        state[0] = s
        state = np.roll(state, -1)  # newest sample moves to the end
        out[i] = s
    return out
def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
    """Undo the perceptual weighting W(z/gamma) frame by frame.

    Each 160-sample frame of `pw_signal` is passed through the all-pole
    filter given by the matching row of `filters`, with the 16-sample
    filter state carried from one frame to the next.

    Args:
        pw_signal: 1-D perceptually-weighted signal (multiple of 160 samples).
        filters: (num_frames, 16) LPC coefficients, one row per frame.
        weighting_vector: per-tap weighting forwarded to the synthesis filter.

    Returns:
        The filtered signal, same shape as `pw_signal`.
    """
    frame_len = 160
    num_frames = pw_signal.shape[0] // frame_len
    assert num_frames == filters.shape[0]
    signal = np.zeros_like(pw_signal)
    buffer = np.zeros(16)
    for k in range(num_frames):
        start = k * frame_len
        chunk = lpc_synthesis_one_frame(pw_signal[start:start + frame_len],
                                        filters[k, :], buffer, weighting_vector)
        signal[start:start + frame_len] = chunk
        # Carry the last 16 output samples as the next frame's filter state.
        buffer[:] = chunk[-16:]
    return signal
def inverse_perceptual_weighting40(pw_signal, filters):
    """Inverse perceptual weighting at the 40-sample subframe rate.

    Same as inverse_perceptual_weighting() but with 40-sample chunks and
    the default (all-ones) tap weighting.

    Args:
        pw_signal: 1-D perceptually-weighted signal (multiple of 40 samples).
        filters: (num_subframes, 16) LPC coefficients, one row per subframe.

    Returns:
        The filtered signal, same shape as `pw_signal`.
    """
    sub_len = 40
    num_frames = pw_signal.shape[0] // sub_len
    assert num_frames == filters.shape[0]
    out = np.zeros_like(pw_signal)
    state = np.zeros(16)
    for k in range(num_frames):
        lo = k * sub_len
        hi = lo + sub_len
        chunk = lpc_synthesis_one_frame(pw_signal[lo:hi], filters[k, :], state)
        out[lo:hi] = chunk
        # Carry the last 16 output samples as the next subframe's state.
        state[:] = chunk[-16:]
    return out
from scipy.signal import lfilter
if __name__ == '__main__':
    model.to(device)
    features = torch.tensor(features).to(device)
    #lpc = torch.tensor(lpc).to(device)
    periods = torch.tensor(periods).to(device)
    # Bandwidth-expand the LPCs to A(z/gamma) and interpolate them to the
    # 40-sample subframe rate (only used by the optional inverse
    # perceptual-weighting path, currently disabled below).
    weighting = gamma**np.arange(1, 17)
    lpc = lpc*weighting
    lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        sig, _ = model(features, periods, nb_frames - 4)
    # BUG FIX: move the tensor back to the host before .numpy() --
    # calling .numpy() on a CUDA tensor raises a TypeError.
    sig = sig.detach().cpu().numpy().flatten()
    # De-emphasis 1/(1 - 0.85 z^-1), undoing the pre-emphasis applied to
    # the training targets.
    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
    #sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
    # Clip, scale to 16-bit range and write raw PCM.
    pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
    pcm.tofile(signal_file)

View File

@@ -0,0 +1,169 @@
# Training script for FARGAN (pre-training stage): reads features plus 16-bit
# PCM speech, trains the vocoder, and writes one checkpoint per epoch.
import os
import argparse
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import fargan
from dataset import FARGANDataset
from stft_loss import *
parser = argparse.ArgumentParser()
parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')
parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
args = parser.parse_args()
# Restrict visible GPUs before any CUDA context is created.
if args.cuda_visible_devices != None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)
# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay
adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal
# model parameters
cond_size = args.cond_size
# Record hyper-parameters in the checkpoint dict so each saved file carries
# enough information to re-instantiate the model and reproduce the run.
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
#model = fargan.FARGAN()
#model = nn.DataParallel(model)
# Optionally warm-start from an earlier checkpoint.
# NOTE(review): this replaces the freshly-built `checkpoint` dict with the
# loaded one, so the CLI hyper-parameters recorded above are discarded in the
# files saved by this run -- confirm this is intended.
if type(args.initial_checkpoint) != type(None):
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)
checkpoint['state_dict'] = model.state_dict()
dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
# learning rate scheduler: hyperbolic decay, stepped once per batch below.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
states = None
spect_loss = MultiResolutionSTFTLoss(device).to(device)
if __name__ == '__main__':
    model.to(device)
    for epoch in range(1, epochs + 1):
        # Running sums for the per-epoch averages shown in the progress bar.
        running_specc = 0
        running_cont_loss = 0
        running_loss = 0
        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
                #print("interp size", lpc.shape)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
                # 90% of batches: truncate to the configured sequence length
                # (keeping 4 extra conditioning frames). 10% of batches: keep
                # the full sequence but halve the batch size.
                if (np.random.rand() > 0.1):
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:,:sequence_length+4,:]
                    periods = periods[:,:sequence_length+4]
                else:
                    target=target[::2, :]
                    #lpc=lpc[::2,:]
                    features=features[::2,:]
                    periods=periods[::2,:]
                target = target.to(device)
                #print(target.shape, lpc.shape)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
                #nb_pre = random.randrange(1, 6)
                # Teacher-force the first nb_pre frames (160 samples each) and
                # let the model generate the remaining frames.
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                sig = torch.cat([pre, sig], -1)
                # Continuity loss on the first generated frame only, so the
                # output joins smoothly onto the teacher-forced prefix.
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
                specc_loss = spect_loss(sig, target.detach())
                loss = .03*cont_loss + specc_loss
                loss.backward()
                optimizer.step()
                #model.clip_weights()
                # The LR scheduler is stepped per batch, not per epoch.
                scheduler.step()
                running_specc += specc_loss.detach().cpu().item()
                running_cont_loss += cont_loss.detach().cpu().item()
                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   specc=f"{running_specc/(i+1):8.5f}",
                                   )
        # save a checkpoint after every epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)