add some code
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from . import quantization
|
||||
from . import sparsification
|
||||
@@ -0,0 +1 @@
|
||||
from .softquant import soft_quant, remove_soft_quant
|
||||
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
def compute_optimal_scale(weight):
    """Compute a per-row quantization scale for a 2-d weight matrix.

    The scale is the smallest per-row factor such that both individual
    entries fit into [-127, 127] and sums of adjacent entry pairs fit
    into [-129, 129] after division by the scale.

    Args:
        weight (torch.Tensor): 2-d tensor of shape (n_out, n_in); n_in must
            be divisible by 4.

    Returns:
        torch.Tensor: 1-d tensor of length n_out with one scale per row.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0

    if n_out % 8:
        # Pad the row dimension up to the next multiple of 8.
        # (Bug fix: the old computation `n_out - n_out % 8` did not produce a
        # multiple-of-8 total. Padded zero rows are sliced off again below,
        # so the returned values are unchanged.)
        pad = -n_out % 8
        weight = torch.cat(
            (weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)),
            dim=0,
        )

    # largest absolute entry per row -> bound for single quantized values
    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    # largest absolute sum of adjacent pairs per row -> bound for pair sums
    weight_max_sum, _ = torch.max(torch.abs(weight[:, 0:n_in:2] + weight[:, 1:n_in:2]), dim=1)

    scale = torch.maximum(weight_max_abs / 127, weight_max_sum / 129)

    # drop scales belonging to padding rows
    return scale[:n_out]
|
||||
|
||||
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draw uniform quantization noise for ``weight``, scaled per output row.

    The weight is first flattened to 2-d in a layout matching the module
    type, noise in [-0.5, 0.5) is drawn and scaled by the optimal
    per-row quantization scale, then reshaped back to the weight's shape.
    Zero weight entries (from sparsification) receive zero noise.
    """
    def _scaled_noise_2d(w):
        # uniform noise, zeroed where the weight is zero, scaled per row
        eps = torch.rand_like(w) - 0.5
        eps[w == 0] = 0  # ignore zero entries from sparsification
        return eps * compute_optimal_scale(w).unsqueeze(-1)

    if isinstance(module, torch.nn.Conv1d):
        # (out, in, k) -> (out, k*in) and back
        flat = weight.permute(0, 2, 1).flatten(1)
        noise = _scaled_noise_2d(flat)
        noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        # (in, out, k) -> (k*out, in) and back
        i, o, k = weight.shape
        noise = _scaled_noise_2d(weight.permute(2, 1, 0).reshape(k * o, i))
        noise = noise.reshape(k, o, i).permute(2, 1, 0)
    elif len(weight.shape) == 2:
        # plain 2-d weight (e.g. Linear / GRU sub-matrix)
        noise = _scaled_noise_2d(weight)
    else:
        raise ValueError('unknown quantization setting')

    return noise
|
||||
|
||||
class SoftQuant:
    """Hook pair that perturbs selected parameters with quantization noise
    during training.

    A pre-forward hook adds noise to each named parameter, and a
    post-forward hook removes the exact same noise again, so the stored
    parameters stay clean while the forward pass sees quantized-like values.
    Both hooks are no-ops when the module is in eval mode.
    """

    # names of the module parameters that receive noise
    names: list
    # noise amplitude; None selects data-dependent scaling via q_scaled_noise
    scale: float

    def __init__(self, names: list, scale: float) -> None:
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Add (before=True) or remove (before=False) quantization noise."""
        if not module.training:
            return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    # per-row scaled noise matched to int8 quantization
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    # fixed-amplitude uniform noise
                    self.quantization_noise[name] = self.scale * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=None, scale=None):
        """Attach soft-quantization hooks to ``module``.

        Args:
            module (torch.nn.Module): module whose parameters are perturbed
            names (list[str], optional): parameter names; defaults to ['weight']
            scale (float, optional): noise amplitude; None selects scaled noise

        Returns:
            SoftQuant: the hook object shared by both registered hooks

        Raises:
            ValueError: if the module lacks one of the named parameters
        """
        if names is None:
            names = ['weight']  # avoid a shared mutable default

        for name in names:
            if not hasattr(module, name):
                raise ValueError(f"module has no parameter named '{name}'")

        fn = SoftQuant(names, scale)

        fn_before = lambda *x: fn(*x, before=True)
        fn_after = lambda *x: fn(*x, before=False)
        # tag the hooks so remove_soft_quant can find them later
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
|
||||
|
||||
|
||||
def soft_quant(module, names=None, scale=None):
    """Attach soft-quantization hooks to ``module`` and return the module.

    Args:
        module (torch.nn.Module): module to perturb during training
        names (list[str], optional): parameter names; defaults to ['weight']
        scale (float, optional): noise amplitude; None selects scaled noise

    Returns:
        torch.nn.Module: the same module, for call chaining
    """
    # None sentinel instead of a mutable default argument
    if names is None:
        names = ['weight']
    SoftQuant.apply(module, names, scale)
    return module
|
||||
|
||||
def remove_soft_quant(module, names=None):
    """Remove soft-quantization hooks previously installed by soft_quant.

    Args:
        module (torch.nn.Module): module to clean up
        names (list[str], optional): parameter-name list the hooks were
            registered with; defaults to ['weight']

    Returns:
        torch.nn.Module: the same module
    """
    if names is None:
        names = ['weight']

    # Bug fix: the old code deleted entries while iterating .items(), which
    # raises RuntimeError in Python 3. Collect matching keys first.
    for hooks in (module._forward_pre_hooks, module._forward_hooks):
        stale = [
            k for k, hook in hooks.items()
            if hasattr(hook, 'sqm')
            and isinstance(hook.sqm, SoftQuant)
            and hook.sqm.names == names
        ]
        for k in stale:
            del hooks[k]

    return module
|
||||
@@ -0,0 +1,2 @@
|
||||
from .relegance import relegance_gradient_weighting, relegance_create_tconv_kernel, relegance_map_relevance_to_input_domain, relegance_resize_relevance_to_input_size
|
||||
from .meta_critic import MetaCritic
|
||||
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
class MetaCritic:
    """Assesses the relevance of discriminator scores.

    Optionally tracks running mean/std statistics of the score gap per
    discriminator (or jointly) and normalizes the relevance accordingly.
    """

    def __init__(self, normalize=False, gamma=0.9, beta=0.0, joint_stats=False):
        """ Class for assessing relevance of discriminator scores

        Args:
            normalize (bool, optional): if True, normalize by tracked score statistics. Defaults to False.
            gamma (float, optional): update rate for tracking discriminator stats. Defaults to 0.9.
            beta (float, optional): minimum confidence related threshold. Defaults to 0.0.
            joint_stats (bool, optional): if True, statistics are shared across all discriminators. Defaults to False.
        """
        self.normalize = normalize
        self.gamma = gamma
        self.beta = beta
        self.joint_stats = joint_stats

        # running statistics keyed by discriminator id (or 0 when joint)
        self.disc_stats = dict()

    def __call__(self, disc_id, real_scores, generated_scores):
        """ calculates relevance from normalized scores

        Args:
            disc_id (any valid key): id for tracking discriminator statistics
            real_scores (torch.tensor): scores for real data
            generated_scores (torch.tensor): scores for generated data; expecting device to match real_scores.device

        Returns:
            torch.tensor: output-domain relevance
        """
        if self.normalize:
            # instantaneous batch statistics of the score gap
            real_std = torch.std(real_scores.detach()).cpu().item()
            gen_std = torch.std(generated_scores.detach()).cpu().item()
            std = (real_std ** 2 + gen_std ** 2) ** 0.5
            mean = torch.mean(real_scores.detach()).cpu().item() - torch.mean(generated_scores.detach()).cpu().item()

            key = 0 if self.joint_stats else disc_id

            # exponential moving average update of tracked statistics
            if key in self.disc_stats:
                self.disc_stats[key]['std'] = self.gamma * self.disc_stats[key]['std'] + (1 - self.gamma) * std
                self.disc_stats[key]['mean'] = self.gamma * self.disc_stats[key]['mean'] + (1 - self.gamma) * mean
            else:
                # small epsilon keeps the std strictly positive
                self.disc_stats[key] = {
                    'std': std + 1e-5,
                    'mean': mean
                }

            std = self.disc_stats[key]['std']
            mean = self.disc_stats[key]['mean']
        else:
            mean, std = 0, 1

        relevance = torch.relu((real_scores - generated_scores - mean) / std + mean - self.beta)

        return relevance
|
||||
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def view_one_hot(index, length):
    """Build a broadcast view shape: ``length`` ones with -1 at ``index``."""
    shape = [1] * length
    shape[index] = -1
    return shape
|
||||
|
||||
def create_smoothing_kernel(widths, gamma=1.5):
    """ creates a truncated gaussian smoothing kernel for the given widths

    Parameters:
    -----------
    widths: list[Int] or torch.LongTensor
        specifies the shape of the smoothing kernel, entries must be > 0.

    gamma: float, optional
        decay factor for gaussian relative to kernel size

    Returns:
    --------
    kernel: torch.FloatTensor
        kernel of shape ``widths``, normalized to sum to one
    """

    widths = torch.LongTensor(widths)
    num_dims = len(widths)

    assert widths.min() > 0

    # center position and gaussian width per dimension
    centers = widths.float() / 2 - 0.5
    sigmas = gamma * (centers + 1)

    # squared normalized distance from the center along each axis
    # (removed a dead `vals = []` that was immediately overwritten)
    vals = [((torch.arange(widths[i]) - centers[i]) / sigmas[i]) ** 2 for i in range(num_dims)]
    # broadcast the per-axis profiles to the full kernel shape and add
    vals = sum(vals[i].view(view_one_hot(i, num_dims)) for i in range(num_dims))

    kernel = torch.exp(-vals)
    kernel = kernel / kernel.sum()

    return kernel
|
||||
|
||||
|
||||
def create_partition_kernel(widths, strides):
    """ creates a partition kernel for mapping a convolutional network output back to the input domain

    Given a fully convolutional network with receptive field of shape widths and the given strides, this
    function constructs an interpolation kernel whose translations by multiples of the given strides form
    a partition of one on the input domain.

    Parameter:
    ----------
    widths: list[Int] or torch.LongTensor
        shape of receptive field

    strides: list[Int] or torch.LongTensor
        total strides of convolutional network

    Returns:
    kernel: torch.FloatTensor
    """

    num_dims = len(widths)
    assert num_dims == len(strides) and num_dims in {1, 2, 3}

    # dimension-specific convolution entry points
    convs = {1 : F.conv1d, 2 : F.conv2d, 3 : F.conv3d}

    widths = torch.LongTensor(widths)
    strides = torch.LongTensor(strides)

    # box prototype: translations by `strides` tile the input domain exactly
    proto_kernel = torch.ones(torch.minimum(strides, widths).tolist())

    # create interpolation kernel eta
    eta_widths = widths - strides + 1
    if eta_widths.min() <= 0:
        # degenerate case: stride exceeds receptive field; clamp to width 1
        print("[create_partition_kernel] warning: receptive field does not cover input domain")
        eta_widths = torch.maximum(eta_widths, torch.ones_like(eta_widths))

    # smoothing kernel reshaped to (out_channels=1, in_channels=1, *spatial)
    eta = create_smoothing_kernel(eta_widths).view(1, 1, *eta_widths.tolist())

    padding = torch.repeat_interleave(eta_widths - 1, 2, 0).tolist()[::-1] # ordering of dimensions for padding and convolution functions is reversed in torch
    padded_proto_kernel = F.pad(proto_kernel, padding)
    padded_proto_kernel = padded_proto_kernel.view(1, 1, *padded_proto_kernel.shape)
    # smoothing the box kernel with eta yields the final partition-of-one kernel
    kernel = convs[num_dims](padded_proto_kernel, eta)

    return kernel
|
||||
|
||||
|
||||
def receptive_field(conv_model, input_shape, output_position):
    """ estimates boundaries of receptive field connected to output_position via autograd

    Parameters:
    -----------
    conv_model: nn.Module or autograd function
        function or model implementing fully convolutional model

    input_shape: List[Int]
        input shape ignoring batch dimension, i.e. [num_channels, dim1, dim2, ...]

    output_position: List[Int]
        output position for which the receptive field is determined; raises
        ValueError when out of bounds for the given input_shape.

    Returns:
    --------
    low: List[Int]
        start indices of receptive field

    high: List[Int]
        stop indices of receptive field
    """

    probe = torch.randn((1,) + tuple(input_shape), requires_grad=True)
    output = conv_model(probe)

    # collapse channels and drop the batch dimension
    output = torch.sum(output, 1)[0]

    # one-hot mask selecting the requested output position
    mask = torch.zeros_like(output)
    index = [torch.tensor(i) for i in output_position]
    try:
        mask.index_put_(index, torch.tensor(1, dtype=mask.dtype))
    except IndexError:
        raise ValueError('output_position out of bounds')

    # backprop from the single selected output value
    (mask * output).sum().backward()

    # the gradient's support w.r.t. the input marks the receptive field
    grad = torch.sum(probe.grad, dim=1)[0]
    support = torch.nonzero(grad, as_tuple=True)
    low = [dim_indices.min().item() for dim_indices in support]
    high = [dim_indices.max().item() for dim_indices in support]

    return low, high
|
||||
|
||||
def estimate_conv_parameters(model, num_channels, num_dims, width, max_stride=10):
    """ attempts to estimate receptive field size, strides and left paddings for given model

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    width: Int
        width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    max_stride: Int, optional
        assumed maximal stride of the model for any dimension, when set too low the function may fail for
        any value of width

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError, KeyError
    """

    input_shape = [num_channels] + num_dims * [width]
    # probe two neighboring output positions near the interior of the output
    output_position1 = num_dims * [width // (2 * max_stride)]
    output_position2 = num_dims * [width // (2 * max_stride) + 1]

    low1, high1 = receptive_field(model, input_shape, output_position1)
    low2, high2 = receptive_field(model, input_shape, output_position2)

    widths1 = [h - l + 1 for l, h in zip(low1, high1)]
    widths2 = [h - l + 1 for l, h in zip(low2, high2)]

    # neighboring receptive fields must have equal size and be shifted in
    # every dimension; otherwise the probe width was too small
    if not all([w1 - w2 == 0 for w1, w2 in zip(widths1, widths2)]) or not all([l1 != l2 for l1, l2 in zip(low1, low2)]):
        raise ValueError("[estimate_strides]: widths to small to determine strides")

    receptive_field_size = widths1
    # stride = shift of the receptive field between neighboring outputs
    strides = [l2 - l1 for l1, l2 in zip(low1, low2)]
    # left padding aligns receptive field position with the input plane
    left_paddings = [s * p - l for l, s, p in zip(low1, strides, output_position1)]

    return receptive_field_size, strides, left_paddings
|
||||
|
||||
def inspect_conv_model(model, num_channels, num_dims, max_width=10000, width_hint=None, stride_hint=None, verbose=False):
    """ determines size of receptive field, strides and padding probabilistically

    Parameters:
    -----------
    model: nn.Module or autograd function
        fully convolutional model for which parameters are estimated

    num_channels: Int
        number of input channels for model

    num_dims: Int
        number of input dimensions for model (without channel dimension)

    max_width: Int
        maximum width of the input tensor (a hyper-square) on which the receptive fields are derived via autograd

    width_hint: Int, optional
        optional hint at the maximal receptive field width

    stride_hint: Int, optional
        optional hint at the maximal stride

    verbose: bool, optional
        if true, the function prints parameters for individual trials

    Returns:
    --------
    receptive_field_size: List[Int]
        receptive field size in all dimension

    strides: List[Int]
        stride in all dimensions

    left_paddings: List[Int]
        left padding in all dimensions; this is relevant for aligning the receptive field on the input plane

    Raises:
    -------
    ValueError
    """

    max_stride = max_width // 2
    stride = max_stride // 100
    width = max_width // 100

    if width_hint is not None: width = 2 * width_hint
    if stride_hint is not None: stride = stride_hint

    did_it = False
    receptive_field_size, strides, left_paddings = None, None, None
    while width < max_width and stride < max_stride:
        try:
            if verbose: print(f"[inspect_conv_model] trying parameters {width=}, {stride=}")
            receptive_field_size, strides, left_paddings = estimate_conv_parameters(model, num_channels, num_dims, width, stride)
            did_it = True
        except (ValueError, KeyError, RuntimeError):
            # probe too small for the assumed stride (or model rejected the
            # input size); retry with larger parameters.
            # Bug fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt and SystemExit.
            pass

        if did_it: break

        # grow the probe; once the width budget is exhausted, double the
        # assumed stride and restart with a proportionate width
        width *= 2
        if width >= max_width and stride < max_stride:
            stride *= 2
            width = 2 * stride

    if not did_it:
        raise ValueError(f'could not determine conv parameter with given max_width={max_width}')

    return receptive_field_size, strides, left_paddings
|
||||
|
||||
|
||||
class GradWeight(torch.autograd.Function):
    """Identity in the forward pass; re-weights gradients element-wise by a
    fixed weight tensor in the backward pass."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, x, weight):
        # stash the weight for backward; forward output is a copy of x
        ctx.save_for_backward(weight)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # element-wise gradient weighting; no gradient w.r.t. weight itself
        return grad_output * weight, None
|
||||
|
||||
|
||||
# API
|
||||
|
||||
def relegance_gradient_weighting(x, weight):
    """Attach element-wise gradient weighting to a tensor.

    Args:
        x (torch.tensor): input tensor
        weight (torch.tensor or None): weight tensor for gradients of x; if
            None, no gradient weighting is applied in the backward pass

    Returns:
        torch.tensor: x unchanged in value (identity in the forward pass)
    """
    # no weight -> plain pass-through without touching autograd
    return x if weight is None else GradWeight.apply(x, weight)
|
||||
|
||||
|
||||
|
||||
def relegance_create_tconv_kernel(model, num_channels, num_dims, width_hint=None, stride_hint=None, verbose=False):
    """ creates parameters for mapping back output domain relevance to input domain

    Args:
        model (nn.Module or autograd.Function): fully convolutional model
        num_channels (int): number of input channels to model
        num_dims (int): number of input dimensions of model (without channel and batch dimension)
        width_hint (int or None): optional hint at maximal width of receptive field
        stride_hint (int or None): optional hint at maximal stride
        verbose (bool): print probing parameters for individual trials

    Returns:
        dict: contains kernel, kernel dimensions, strides and left paddings for transposed convolution

    Raises:
        RuntimeError: if estimation of parameters fails within the compute budget
    """

    # probing budget shrinks with dimensionality (hyper-cube input)
    max_width = int(100000 / (10 ** num_dims))

    # Bug fix: both fallback paths used bare `except:`; narrowed to the
    # exceptions the probing can actually raise.
    try:
        receptive_field_size, strides, left_paddings = inspect_conv_model(
            model, num_channels, num_dims, max_width=max_width,
            width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
    except (ValueError, RuntimeError):
        # try once again with larger max_width; give up after that
        max_width *= 10
        try:
            receptive_field_size, strides, left_paddings = inspect_conv_model(
                model, num_channels, num_dims, max_width=max_width,
                width_hint=width_hint, stride_hint=stride_hint, verbose=verbose)
        except (ValueError, RuntimeError):
            raise RuntimeError("could not determine parameters within given compute budget")

    partition_kernel = create_partition_kernel(receptive_field_size, strides)
    # replicate the single-channel kernel across all input channels
    partition_kernel = torch.repeat_interleave(partition_kernel, num_channels, 1)

    tconv_parameters = {
        'kernel': partition_kernel,
        'receptive_field_shape': receptive_field_size,
        'stride': strides,
        'left_padding': left_paddings,
        'num_dims': num_dims
    }

    return tconv_parameters
|
||||
|
||||
|
||||
|
||||
def relegance_map_relevance_to_input_domain(od_relevance, tconv_parameters):
    """ maps output-domain relevance to input-domain relevance via transposed convolution

    Args:
        od_relevance (torch.tensor): output-domain relevance
        tconv_parameters (dict): parameter dict as created by relegance_create_tconv_kernel

    Returns:
        torch.tensor: input-domain relevance. The tensor is left aligned, i.e. the all-zero index of the output corresponds to the all-zero index of the discriminator input.
        Otherwise, the size of the output tensor does not need to match the size of the discriminator input. Use relegance_resize_relevance_to_input_size for a
        convenient way to adjust the output to the correct size.

    Raises:
        ValueError: if number of dimensions is not supported
    """

    kernel = tconv_parameters['kernel'].to(od_relevance.device)
    rf_shape = tconv_parameters['receptive_field_shape']
    stride = tconv_parameters['stride']
    left_padding = tconv_parameters['left_padding']

    num_dims = len(kernel.shape) - 2

    # repeat boundary values so edge positions receive full kernel support
    od_padding = [rf_shape[i // 2] // stride[i // 2] + 1 for i in range(2 * num_dims)]
    padded_od_relevance = F.pad(od_relevance, od_padding[::-1], mode='replicate')
    od_padding = od_padding[::2]

    # apply mapping and left trimming
    if num_dims == 1:
        id_relevance = F.conv_transpose1d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :]
    elif num_dims == 2:
        id_relevance = F.conv_transpose2d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1] :]
    elif num_dims == 3:
        # bug fix: this branch previously called F.conv_transpose2d on 3-d data
        id_relevance = F.conv_transpose3d(padded_od_relevance, kernel, stride=stride)
        id_relevance = id_relevance[..., left_padding[0] + stride[0] * od_padding[0] :, left_padding[1] + stride[1] * od_padding[1] :, left_padding[2] + stride[2] * od_padding[2] :]
    else:
        raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')

    return id_relevance
|
||||
|
||||
|
||||
def relegance_resize_relevance_to_input_size(reference_input, relevance):
    """ adjusts size of relevance tensor to reference input size

    Args:
        reference_input (torch.tensor): discriminator input tensor for reference
        relevance (torch.tensor): input-domain relevance corresponding to input tensor reference_input

    Returns:
        torch.tensor: relevance cropped/zero-extended to reference_input's shape

    Raises:
        ValueError: if number of dimensions is not supported
    """
    resized = torch.zeros_like(reference_input)

    num_dims = len(reference_input.shape) - 2
    with torch.no_grad():
        if num_dims not in {1, 2, 3}:
            raise ValueError(f'[relegance_map_to_input_domain] error: num_dims = {num_dims} not supported')
        # crop each trailing spatial dim to the smaller of the two sizes
        crop = tuple(
            slice(0, min(reference_input.size(d), relevance.size(d)))
            for d in range(-num_dims, 0)
        )
        resized[:] = relevance[(Ellipsis,) + crop]

    return resized
|
||||
@@ -0,0 +1,6 @@
|
||||
from .gru_sparsifier import GRUSparsifier
|
||||
from .conv1d_sparsifier import Conv1dSparsifier
|
||||
from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
|
||||
from .linear_sparsifier import LinearSparsifier
|
||||
from .common import sparsify_matrix, calculate_gru_flops_per_step
|
||||
from .utils import mark_for_sparsification, create_sparsifier
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
class BaseSparsifier:
    """Common scheduling logic for sparsifiers.

    Between steps `start` and `stop` an interpolation factor alpha decays
    from 1 to 0 following a power law; subclasses implement sparsify(alpha).
    """

    def __init__(self, task_list, start, stop, interval, exponent=3):
        # schedule parameters
        self.task_list = task_list
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent

        # number of calls to step() so far
        self.step_counter = 0

    def step(self, verbose=False):
        """Advance the schedule by one step; call sparsify() when due."""
        self.step_counter += 1
        count = self.step_counter

        # before the ramp starts: nothing to do
        if count < self.start:
            return

        if count >= self.stop:
            # ramp finished: target density reached
            alpha = 0
        else:
            # during the ramp only act every `interval`-th step
            if count % self.interval:
                return
            alpha = ((self.stop - count) / (self.stop - self.start)) ** self.exponent

        self.sparsify(alpha, verbose=verbose)
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
debug=True  # module-level debug flag; not referenced in the visible code — TODO confirm usage before removing
|
||||
|
||||
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ sparsifies matrix with specified block size

    Parameters:
    -----------
    matrix : torch.tensor
        matrix to sparsify
    density : float
        target density (fraction of blocks kept)
    block_size : [int, int]
        block size dimensions
    keep_diagonal : bool
        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
    return_mask : bool
        If true, the binary block mask is returned alongside the matrix

    Returns:
    --------
    masked_matrix : torch.tensor (and mask : torch.tensor if return_mask is True)
    """

    m, n = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # masking threshold.
    # Bug fix: with zero survivors the old threshold of 0 kept ALL blocks
    # (no energy is < 0); an infinite threshold masks everything instead.
    if number_of_survivors == 0:
        threshold = float("inf")
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]

    # create mask: keep blocks whose energy reaches the threshold
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking and restore the spared diagonal
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix
|
||||
|
||||
def calculate_gru_flops_per_step(gru, sparsification_dict=None, drop_input=False):
    """ estimates the number of floating point operations per GRU time step

    Parameters:
    -----------
    gru : torch.nn.GRU
        GRU whose cost is estimated; only input_size and hidden_size are read

    sparsification_dict : dict, optional
        maps weight names ('W_ir', 'W_iz', 'W_in', 'W_hr', 'W_hz', 'W_hn') to
        tuples whose first entry is the density in [0, 1]; missing keys are
        treated as dense (density 1)

    drop_input : bool
        if True, the input matrix-vector multiplications are excluded from the count

    Returns:
    --------
    float : estimated flops per step
    """
    # avoid mutable default argument; an empty dict means fully dense
    if sparsification_dict is None:
        sparsification_dict = {}

    input_size  = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class Conv1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Conv1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv1d, params), where conv1d is an instance
            of torch.nn.Conv1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.Conv1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the exponent argument instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # block mask from the previous sparsification step, used to detect
        # weights coming back to life between steps
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for conv, params in self.task_list:
                # use weight_v when weight normalization is active
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # flatten (out, in, kernel) to a 2d matrix for block sparsification
                i, o, k = weight.shape
                w = weight.permute(0, 2, 1).flatten(1)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(i, k, o).permute(0, 2, 1)
                weight[:] = w

                if self.last_mask is not None:
                    # a block present now but absent last step indicates resurrection
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"conv1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
# smoke test: sparsify a small Conv1d for 100 steps and print the result
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch
    conv = torch.nn.Conv1d(8, 16, 8)
    params = (0.2, [8, 4])  # (target density, block size)

    sparsifier = Conv1dSparsifier([(conv, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)
|
||||
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class ConvTranspose1dSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.ConvTranspose1d layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (conv1d, params), where conv1d is an instance
            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the exponent argument instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # block mask from the previous sparsification step
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for conv, params in self.task_list:
                # use weight_v when weight normalization is active
                if hasattr(conv, 'weight_v'):
                    weight = conv.weight_v
                else:
                    weight = conv.weight
                # ConvTranspose1d weight is (in, out, kernel); fold kernel and
                # out dims into rows before block sparsification
                i, o, k = weight.shape
                w = weight.permute(2, 1, 0).reshape(k * o, i)
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
                w = w.reshape(k, o, i).permute(2, 1, 0)
                weight[:] = w

                if self.last_mask is not None:
                    # a block present now but absent last step indicates resurrection
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in conv.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
# smoke test: sparsify a small ConvTranspose1d for 100 steps and print the result
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch
    conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
    params = (0.2, [8, 4])  # (target density, block size)

    sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(conv.weight)
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base_sparsifier import BaseSparsifier
|
||||
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class GRUSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
            should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to GRUSparsifier.step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...     'W_ir' : (0.5, [2, 2], False),
        ...     'W_iz' : (0.6, [2, 2], False),
        ...     'W_in' : (0.7, [2, 2], False),
        ...     'W_hr' : (0.1, [4, 4], True),
        ...     'W_hz' : (0.2, [4, 4], True),
        ...     'W_hn' : (0.3, [4, 4], True),
        ...     }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the exponent argument instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # per-gate block masks from the previous step, used to detect
        # weights coming back to life between steps
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: rows [i*h, (i+1)*h) of weight_ih_l0 hold the
                # W_ir / W_iz / W_in sub-matrices (PyTorch gate order r, z, n)
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        # use weight_ih_l0_v when weight normalization is active
                        if hasattr(gru, 'weight_ih_l0_v'):
                            weight = gru.weight_ih_l0_v
                        else:
                            weight = gru.weight_ih_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, : ],
                            density, # density
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_ih_l0_v")

                        self.last_masks[key] = new_mask

                # recurrent weights: same layout in weight_hh_l0
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        # use weight_hh_l0_v when weight normalization is active
                        if hasattr(gru, 'weight_hh_l0_v'):
                            weight = gru.weight_hh_l0_v
                        else:
                            weight = gru.weight_hh_l0
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")
                        weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
                            weight[i * hidden_size : (i + 1) * hidden_size, : ],
                            density,
                            params[key][1], # block_size
                            params[key][2], # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        # bug fix: the resurrection warning was gated on
                        # 'and True' here, unlike the input branch's 'and debug',
                        # so it always printed
                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
                                print("weight resurrection in weight_hh_l0_v")

                        self.last_masks[key] = new_mask
|
||||
|
||||
|
||||
|
||||
# smoke test: sparsify all six gate matrices of a small GRU for 100 steps
if __name__ == "__main__":
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    # per-gate (target density, block size, keep_diagonal)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(gru.weight_hh_l0)
|
||||
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
import torch

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug
|
||||
|
||||
|
||||
class LinearSparsifier(BaseSparsifier):
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.Linear layers

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (linear, params), where linear is an instance
            of torch.nn.Linear and params is a tuple (density, [m, n]),
            where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
            sparsification is applied.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop sparsification will be
            carried out after every call to step()

        exponent : float
            Interpolation exponent for sparsification interval. In step i sparsification will be carried out
            with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent

        Example:
        --------
        >>> import torch
        >>> linear = torch.nn.Linear(8, 16)
        >>> params = (0.2, [8, 4])
        >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # bug fix: forward the exponent argument instead of hard-coding 3
        super().__init__(task_list, start, stop, interval, exponent=exponent)

        # block mask from the previous sparsification step
        self.last_mask = None

    def sparsify(self, alpha, verbose=False):
        """ carries out sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        ----------
        alpha : float
            density interpolation parameter (1: dense, 0: target density)
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """

        with torch.no_grad():
            for linear, params in self.task_list:
                # use weight_v when weight normalization is active
                if hasattr(linear, 'weight_v'):
                    weight = linear.weight_v
                else:
                    weight = linear.weight
                target_density, block_size = params
                density = alpha + (1 - alpha) * target_density
                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

                if self.last_mask is not None:
                    # bug fix: message said "conv.weight" (copy-paste from the
                    # Conv1d sparsifier); this class handles Linear layers
                    if not torch.all(self.last_mask * new_mask == new_mask) and debug:
                        print("weight resurrection in linear.weight")

                self.last_mask = new_mask

                if verbose:
                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")
|
||||
|
||||
|
||||
# smoke test: sparsify a small Linear layer for 100 steps and print the result
if __name__ == "__main__":
    print("Testing sparsifier")

    import torch
    linear = torch.nn.Linear(8, 16)
    params = (0.2, [4, 2])  # (target density, block size)

    sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)

    for i in range(100):
        sparsifier.step(verbose=True)

    print(linear.weight)
|
||||
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
|
||||
from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
|
||||
|
||||
def mark_for_sparsification(module, params):
    """Tag a module for sparsification and attach its sparsification parameters.

    Sets the attributes ``sparsify`` (True) and ``sparsification_params``
    (the given params) on the module, which create_sparsifier later inspects.
    Returns the same module for call chaining.
    """
    module.sparsify = True
    module.sparsification_params = params
    return module
|
||||
|
||||
def create_sparsifier(module, start, stop, interval):
    """Build a combined sparsification step function for a module tree.

    Walks all sub-modules of *module*, creates the matching sparsifier for
    every sub-module previously tagged with mark_for_sparsification, and
    returns a closure that advances all of them by one step.

    Parameters:
    -----------
    module : torch.nn.Module
        root module whose tagged sub-modules will be sparsified
    start, stop, interval : int
        schedule parameters forwarded to each sparsifier

    Returns:
    --------
    callable : sparsify(verbose=False) that steps every created sparsifier
    """
    sparsifier_list = []
    for m in module.modules():
        # only modules tagged by mark_for_sparsification carry 'sparsify'
        if hasattr(m, 'sparsify'):
            if isinstance(m, torch.nn.GRU):
                sparsifier_list.append(
                    GRUSparsifier([(m, m.sparsification_params)], start, stop, interval)
                )
            elif isinstance(m, torch.nn.Linear):
                sparsifier_list.append(
                    LinearSparsifier([(m, m.sparsification_params)], start, stop, interval)
                )
            elif isinstance(m, torch.nn.Conv1d):
                sparsifier_list.append(
                    Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
                )
            elif isinstance(m, torch.nn.ConvTranspose1d):
                sparsifier_list.append(
                    ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
                )
            else:
                # tagged but unsupported module type: warn and skip
                print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")

    # closure keeps sparsifier_list alive; call once per training step
    def sparsify(verbose=False):
        for sparsifier in sparsifier_list:
            sparsifier.step(verbose)

    return sparsify
|
||||
|
||||
|
||||
def count_parameters(model, verbose=False):
    """Count the total number of parameters in a model.

    Parameters:
    -----------
    model : torch.nn.Module
        model whose named parameters are counted
    verbose : bool
        if True, the per-parameter counts are printed

    Returns:
    --------
    int : total number of parameter elements
    """
    total = 0
    for name, p in model.named_parameters():
        # p.numel() gives the element count directly; the previous
        # torch.ones_like(p).sum().item() allocated a full tensor per parameter
        count = p.numel()

        if verbose:
            print(f"{name}: {count} parameters")

        total += count

    return total
|
||||
|
||||
def estimate_nonzero_parameters(module):
    """Estimate how many weight parameters sparsification will zero out.

    NOTE(review): despite the function name, every branch computes the number
    of *zeroed* parameters (count * (1 - density)), matching the local variable
    name; the name/semantics mismatch predates this change — confirm intent.

    Parameters:
    -----------
    module : torch.nn.Module
        module, optionally tagged by mark_for_sparsification; untagged
        modules yield 0

    Returns:
    --------
    float : estimated number of zeroed weight parameters

    Raises:
    -------
    ValueError : if the module is tagged but of an unsupported type
    """
    num_zero_parameters = 0
    if hasattr(module, 'sparsify'):
        params = module.sparsification_params
        if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
            num_zero_parameters = module.weight.numel() * (1 - params[0])
        elif isinstance(module, torch.nn.GRU):
            num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
            num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
        elif isinstance(module, torch.nn.Linear):
            # bug fix: previously multiplied by params[0] (the density), i.e. the
            # count of *surviving* parameters, inconsistent with the other branches
            num_zero_parameters = module.in_features * module.out_features * (1 - params[0])
        else:
            raise ValueError(f'unknown sparsification method for module of type {type(module)}')

    # bug fix: the original computed num_zero_parameters but never returned it
    return num_zero_parameters
|
||||
@@ -0,0 +1 @@
|
||||
torch
|
||||
48
managed_components/78__esp-opus/dnn/torch/dnntools/setup.py
Normal file
48
managed_components/78__esp-opus/dnn/torch/dnntools/setup.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
/* Copyright (c) 2023 Amazon
|
||||
Written by Jan Buethe */
|
||||
/*
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
"""
|
||||
|
||||
#!/usr/bin/env/python
# NOTE(review): shebang looks malformed ('env/python' instead of 'env python')
# and sits after the license docstring, so it is inert either way — confirm.
import os
from setuptools import setup

# directory containing this setup.py; requirements.txt is resolved relative to it
lib_folder = os.path.dirname(os.path.realpath(__file__))

# read pinned dependencies from requirements.txt
with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f:
    install_requires = list(f.read().splitlines())

print(install_requires)

setup(name='dnntools',
      version='1.0',
      author='Jan Buethe',
      author_email='jbuethe@amazon.de',
      description='Non-Standard tools for deep neural network training with PyTorch',
      packages=['dnntools', 'dnntools.sparsification', 'dnntools.quantization'],
      install_requires=install_requires
      )
|
||||
Reference in New Issue
Block a user