Mixed Precision

Experimental support for mixed precision has also been added:

from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam
from gradient_accumulator import GradientAccumulateModel

# enable mixed precision globally before constructing the model
mixed_precision.set_global_policy('mixed_float16')

# wrap an existing Keras model; mixed_precision=True enables the wrapper's mixed-precision handling
model = GradientAccumulateModel(accum_steps=4, mixed_precision=True, inputs=model.input, outputs=model.output)

# use a slightly larger epsilon and wrap the optimizer for dynamic loss scaling
opt = Adam(1e-3, epsilon=1e-4)
opt = mixed_precision.LossScaleOptimizer(opt)
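
A minimal end-to-end sketch of how these pieces fit together. The small Sequential model, the loss, and the dataset name train_ds are stand-ins for your own setup, not part of the library:

import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam
from gradient_accumulator import GradientAccumulateModel

mixed_precision.set_global_policy('mixed_float16')

# stand-in architecture; the last layer is kept in float32 so the
# softmax output stays numerically stable under mixed precision
base = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax', dtype='float32'),
])

# accumulate gradients over 4 mini-batches before applying an update
model = GradientAccumulateModel(accum_steps=4, mixed_precision=True,
                                inputs=base.input, outputs=base.output)

# loss scaling guards against float16 gradient underflow
opt = mixed_precision.LossScaleOptimizer(Adam(1e-3, epsilon=1e-4))
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy')

# train_ds is a hypothetical tf.data.Dataset yielding (features, labels) batches
# model.fit(train_ds, epochs=3)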

If using TPUs, use bfloat16 instead of float16, like so:

mixed_precision.set_global_policy('mixed_bfloat16')
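
For context, a sketch of where this policy fits in a standard TensorFlow TPU setup (the TPU initialization below is ordinary TensorFlow boilerplate and assumes the code runs in a TPU environment); the model wrapping and compilation stay the same as in the float16 example above, just placed inside the strategy scope:

import tensorflow as tf
from tensorflow.keras import mixed_precision

# connect to the TPU and build a distribution strategy
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# bfloat16 shares float32's exponent range, which is why it is preferred on TPUs
mixed_precision.set_global_policy('mixed_bfloat16')

# build, wrap, and compile the model inside strategy.scope(), as above:
# with strategy.scope():
#     model = GradientAccumulateModel(accum_steps=4, mixed_precision=True,
#                                     inputs=model.input, outputs=model.output)
#     ...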

There is also an example of how to use gradient accumulation with mixed precision here.