from typing import Callable, Iterable, Sequence, Union

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.length = length
        # The first `length` args are the explicit inputs; the rest are the
        # parameters that `run_function` uses internally.
        input_tensors = list(args[:length])
        input_params = list(args[length:])
        ctx.save_for_backward(*input_tensors, *input_params)
        # Run the forward pass without building a graph; intermediate
        # activations are discarded and recomputed during backward.
        with torch.no_grad():
            output_tensors = ctx.run_function(*input_tensors)
        return output_tensors

    @staticmethod
    @custom_bwd
    def backward(ctx, *output_grads):
        inputs = ctx.saved_tensors
        input_tensors = inputs[: ctx.length]
        input_params = inputs[ctx.length :]
        # Delegate the recomputation and differentiation to a second autograd
        # Function so that the returned gradients themselves remain
        # differentiable (double backward through the checkpointed region).
        res = CheckpointFunctionGradFunction.apply(
            ctx.run_function,
            len(input_tensors),
            len(input_params),
            *input_tensors,
            *input_params,
            *output_grads,
        )
        return (None, None) + res


class CheckpointFunctionGradFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length_1, length_2, *args):
        ctx.run_function = run_function
        ctx.length_1 = length_1
        ctx.length_2 = length_2
        # args layout: [inputs (length_1)] + [params (length_2)] + [output grads].
        input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]
        input_params = list(args[length_1 : length_1 + length_2])
        output_grads = list(args[length_1 + length_2 :])
        ctx.save_for_backward(*input_tensors, *input_params, *output_grads)
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
        )
        return input_grads

    @staticmethod
    @custom_bwd
    def backward(ctx, *all_output_grads):
        args = ctx.saved_tensors
        input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]
        input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])
        output_grads = [
            x.detach().requires_grad_(True)
            for x in args[ctx.length_1 + ctx.length_2 :]
        ]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # Recompute the first-order gradients with create_graph=True so they
        # can be differentiated a second time below.
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
            create_graph=True,
            retain_graph=True,
        )
        input_grads_grads = torch.autograd.grad(
            input_grads,
            input_tensors + input_params + output_grads,
            all_output_grads,
            allow_unused=True,
        )
        del input_grads
        return (None, None, None) + input_grads_grads
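

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): the `_TinyBlock`
# module, the tensor shapes, and the equivalence check below are hypothetical.
# They show the intended calling convention: the tensors that `func` takes as
# arguments go in `inputs`, while the parameters it uses internally go in
# `params`, so their gradients can be recovered when the forward pass is
# recomputed during backward.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class _TinyBlock(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    block = _TinyBlock()
    x = torch.randn(4, 8, requires_grad=True)

    # Checkpointed pass: activations inside `block` are recomputed in backward.
    y = checkpoint(block, (x,), list(block.parameters()), flag=True)
    y.sum().backward()
    grad_ckpt = x.grad.clone()

    # Reference pass without checkpointing.
    x.grad = None
    block.zero_grad()
    y_ref = checkpoint(block, (x,), list(block.parameters()), flag=False)
    y_ref.sum().backward()

    # Gradients w.r.t. the input should agree between the two paths.
    assert torch.allclose(grad_ckpt, x.grad)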