Gradient Accumulation Wrappers
GradientAccumulateModel | Model wrapper for gradient accumulation.
GradientAccumulateOptimizer | Optimizer wrapper for gradient accumulation.
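Both wrappers slot into a standard Keras workflow. A minimal usage sketch (the model, the accum_steps value, and the GradientAccumulateModel keyword arguments are illustrative assumptions, not part of this reference):

    import tensorflow as tf
    from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer

    # A plain Keras model to wrap.
    base = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
        tf.keras.layers.Dense(1),
    ])

    # Option A: wrap the model so fit() accumulates gradients over 4 batches.
    model = GradientAccumulateModel(accum_steps=4, inputs=base.input, outputs=base.output)
    model.compile(optimizer=tf.keras.optimizers.SGD(1e-2), loss="mse")

    # Option B: wrap the optimizer instead and compile any model with it.
    opt = GradientAccumulateOptimizer(optimizer=tf.keras.optimizers.SGD(1e-2), accum_steps=4)
    base.compile(optimizer=opt, loss="mse")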
GradientAccumulateOptimizer
- class gradient_accumulator.GradientAccumulateOptimizer(optimizer='SGD', accum_steps=1, reduction: str = 'MEAN', name: str = 'GradientAccumulateOptimizer', **kwargs)[source]
Optimizer wrapper for gradient accumulation.
- apply_gradients(grads_and_vars, name=None, **kwargs)[source]
Accumulates gradients and updates the weights with the wrapped optimizer once every accum_steps calls.
- Parameters:
grads_and_vars – list of (gradient, variable) pairs to accumulate and apply.
name – name to set when applying gradients.
**kwargs – keyword arguments.
- Returns:
Updated weights.
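In a custom training loop, the wrapper is called like any other Keras optimizer; gradients are accumulated across calls, and the wrapped optimizer's update only materializes once every accum_steps steps. A minimal sketch (the variable and loss are placeholders):

    import tensorflow as tf
    from gradient_accumulator import GradientAccumulateOptimizer

    opt = GradientAccumulateOptimizer(optimizer=tf.keras.optimizers.SGD(1e-2), accum_steps=4)
    w = tf.Variable([1.0, 2.0])

    for step in range(8):
        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(tf.square(w))
        grads = tape.gradient(loss, [w])
        opt.apply_gradients(zip(grads, [w]))  # weights only move every 4th call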
- classmethod from_config(config, custom_objects=None)[source]
Creates an instance from a config by deserializing the wrapped optimizer.
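A serialization round-trip sketch, assuming the wrapper also implements get_config symmetrically (as Keras optimizers do), reusing opt from the sketch above:

    config = opt.get_config()
    clone = GradientAccumulateOptimizer.from_config(config)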
- property gradients[source]
The accumulated gradients on the current replica.
- Returns:
Current gradients in optimizer.
- property iterations[source]
Returns the current iteration count of the wrapped optimizer.
- Returns:
Iterations of the wrapped optimizer.
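Both properties can be read mid-training for logging or debugging; continuing the training-loop sketch above:

    print(int(opt.iterations))   # update count of the wrapped optimizer
    accum = opt.gradients        # accumulated gradients on this replica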
Custom Layers
AccumBatchNormalization
- class gradient_accumulator.AccumBatchNormalization(*args, **kwargs)[source]
Custom batch normalization layer with gradient accumulation support.
- build(input_shape)[source]
Builds layer and variables.
- Parameters:
input_shape – input feature map size.
- call(inputs, training=None, mask=None)[source]
Performs the batch normalization step.
- Parameters:
inputs – input feature map to apply batch normalization across.
training – whether layer should be in training mode or not.
mask – whether to calculate statistics within masked region of feature map.
- Returns:
Normalized feature map.
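The layer is intended as a drop-in replacement for keras.layers.BatchNormalization when gradients are accumulated; a usage sketch (the accum_steps keyword is assumed here to mirror the wrappers above):

    import tensorflow as tf
    from gradient_accumulator import AccumBatchNormalization

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, input_shape=(16,)),
        AccumBatchNormalization(accum_steps=4),  # accum_steps assumed, see note above
        tf.keras.layers.Activation("relu"),
        tf.keras.layers.Dense(1),
    ])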
- get_moving_average(statistic, new_value)[source]
Returns the updated moving average given the running statistic and the current estimate.
- Parameters:
statistic – running summary statistic, e.g. the average of a single feature across multiple samples.
new_value – statistic of a single feature for a single forward step.
- Returns:
Updated statistic.
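For reference, a standard exponential moving average update has the form \(s \leftarrow m \cdot s + (1 - m) \cdot v\), where \(s\) is the running statistic, \(v\) is the new_value from the current forward step, and \(m\) is the momentum; the exact momentum attribute used is an implementation detail of the layer.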
General Utilities
Adaptive Gradient Clipping
- gradient_accumulator.compute_norm(x, axis, keepdims)[source]
Computes the Euclidean norm of a tensor \(x\).
- Parameters:
x – input tensor.
axis – which axis to compute norm across.
keepdims – whether to keep the reduced dimension after computing the norm along axis.
- Returns:
Euclidean norm.
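For illustration, the norm computed here is the familiar square root of the sum of squares along the given axis; a sketch of the definition, not necessarily the library's exact implementation:

    import tensorflow as tf

    def euclidean_norm(x, axis, keepdims):
        # Square root of the sum of squares along `axis`.
        return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims))

    x = tf.constant([[3.0, 4.0], [6.0, 8.0]])
    print(euclidean_norm(x, axis=1, keepdims=True))  # [[5.], [10.]]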
- gradient_accumulator.unitwise_norm(x)[source]
Wrapper which dynamically sets axis and keepdims from the shape of the input \(x\) when computing the Euclidean norm.
- Parameters:
x – input tensor.
- Returns:
Euclidean norm.
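In the reference AGC implementations linked under adaptive_clip_grad below, the axis choice depends on tensor rank: biases use the whole tensor, dense kernels reduce over the input axis, and 2D conv kernels reduce over all but the output-channel axis. An illustrative sketch of that dispatch, not the library's exact code:

    import tensorflow as tf

    def unitwise_norm_sketch(x):
        rank = len(x.shape)
        if rank <= 1:             # scalars and bias vectors
            axis, keepdims = None, False
        elif rank in (2, 3):      # dense kernels: one norm per output unit
            axis, keepdims = 0, True
        elif rank == 4:           # conv kernels (HWIO): one norm per output channel
            axis, keepdims = (0, 1, 2), True
        else:
            raise ValueError(f"Unsupported tensor rank: {rank}")
        return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims))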
- gradient_accumulator.adaptive_clip_grad(parameters, gradients, clip_factor: float = 0.01, eps: float = 0.001)[source]
Performs adaptive gradient clipping on a given set of parameters and gradients.
Official JAX implementation (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
Ross Wightman's PyTorch implementation: https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/agc.py
- Parameters:
parameters – Which parameters to apply method on.
gradients – Which gradients to apply clipping on.
clip_factor – upper limit on the ratio of the gradient norm to the parameter norm before clipping is applied.
eps – epsilon; small constant inside \(\max()\) to avoid a zero norm and preserve numerical stability.
- Returns:
Updated gradients after gradient clipping.
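Unit-wise, AGC rescales a gradient \(g\) with corresponding parameter \(w\) roughly as \(g \leftarrow g \cdot \min\left(1, \lambda \max(\lVert w \rVert, \epsilon) / \lVert g \rVert\right)\), with \(\lambda\) = clip_factor; see the linked references for the exact formulation. A usage sketch with placeholder variables:

    import tensorflow as tf
    from gradient_accumulator import adaptive_clip_grad

    w = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(tf.square(w))
    grads = tape.gradient(loss, [w])

    # Clip before handing the gradients to an optimizer.
    clipped = adaptive_clip_grad([w], grads, clip_factor=0.01, eps=1e-3)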