PyTorch
For PyTorch, I would recommend using accelerate. Hugging Face has a great tutorial on how to use it here.
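For illustration, here is a minimal sketch of what that can look like (the `model`, `optimizer`, `criterion`, and `data_loader` are assumed to be the same objects as in the native example below; see the tutorial for the full picture). With the `accumulate()` context manager, accelerate scales the loss and defers the actual optimizer step until the accumulation boundary is reached:

```python
from accelerate import Accelerator

# accumulate gradients over 4 batches before each optimizer step
accelerator = Accelerator(gradient_accumulation_steps=4)

# prepare() wraps the objects and handles device placement for us
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

for inputs, labels in data_loader:
    with accelerator.accumulate(model):
        # forward pass
        preds = model(inputs)
        loss = criterion(preds, labels)
        # backward() scales the loss for accumulation internally
        accelerator.backward(loss)
        # accelerate skips the actual update between accumulation boundaries
        optimizer.step()
        optimizer.zero_grad()
```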
However, if you wish to stick with native PyTorch and implement your own training loop, you can do something like this:
```python
# batch accumulation parameter
accum_iter = 4

# loop through enumerated batches
for batch_idx, (inputs, labels) in enumerate(data_loader):
    # move inputs and labels to the target device
    inputs = inputs.to(device)
    labels = labels.to(device)

    # forward pass
    preds = model(inputs)
    loss = criterion(preds, labels)

    # normalize loss to account for batch accumulation
    loss = loss / accum_iter

    # backward pass (gradients accumulate in .grad across iterations)
    loss.backward()

    # update weights and reset gradients every accum_iter batches,
    # and also on the final (possibly partial) group of batches
    if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
        optimizer.step()
        optimizer.zero_grad()
```
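Note that dividing the loss by `accum_iter` is what makes the accumulated gradient match the gradient of one large batch: assuming your criterion averages over the batch (the default `reduction='mean'` for most PyTorch losses), the effective batch size becomes `accum_iter` times the loader's batch size. The `batch_idx + 1 == len(data_loader)` check simply flushes the final group of batches when the number of batches is not an exact multiple of `accum_iter`.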