a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
4.1 KiB

2 years ago
from typing import Callable, Iterable, Sequence, Union
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Call `func` on `inputs`, optionally through gradient checkpointing.

    With checkpointing enabled, intermediate activations are not cached
    during the forward pass; memory is saved at the price of additional
    computation when gradients are requested.

    :param func: the callable to evaluate.
    :param inputs: positional arguments forwarded to `func`.
    :param params: tensors `func` depends on without receiving them as
                   explicit arguments.
    :param flag: when False, `func` is called directly and no
                 checkpointing takes place.
    :return: whatever `func` returns.
    """
    if not flag:
        return func(*inputs)

    # Explicit inputs come first, followed by the implicit parameters; the
    # input count passed alongside tells CheckpointFunction where to split.
    flat_args = (*inputs, *params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that evaluates a callable under `torch.no_grad()` in
    the forward pass and hands the gradient computation to
    `CheckpointFunctionGradFunction` in the backward pass.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length, *args):
        """
        Evaluate `run_function` on the leading `length` entries of `args`.

        :param ctx: autograd context; receives `run_function`, `length` and
                    every tensor in `args`.
        :param run_function: the callable to evaluate.
        :param length: how many leading entries of `args` are passed to
                       `run_function`; the rest are saved but not passed.
        :param args: input tensors followed by parameter tensors.
        :return: the value returned by `run_function`.
        """
        ctx.run_function, ctx.length = run_function, length
        # Inputs and trailing parameters are saved together, in call order.
        ctx.save_for_backward(*args[:length], *args[length:])
        with torch.no_grad():
            return run_function(*args[:length])

    @staticmethod
    @custom_bwd
    def backward(ctx, *output_grads):
        """
        Compute gradients for the saved tensors by delegating to
        `CheckpointFunctionGradFunction`.

        :param ctx: autograd context populated by `forward`.
        :param output_grads: gradients with respect to the forward outputs.
        :return: two `None` entries (for `run_function` and `length`)
                 followed by the gradients returned by
                 `CheckpointFunctionGradFunction`.
        """
        saved = ctx.saved_tensors
        split = ctx.length
        tensor_part, param_part = saved[:split], saved[split:]
        grads = CheckpointFunctionGradFunction.apply(
            ctx.run_function,
            len(tensor_part),
            len(param_part),
            *tensor_part,
            *param_part,
            *output_grads,
        )
        return (None, None) + grads
class CheckpointFunctionGradFunction(torch.autograd.Function):
    """
    Autograd function whose forward pass re-evaluates a callable with
    gradients enabled and returns `torch.autograd.grad` of its outputs, and
    whose backward pass differentiates those gradients once more.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length_1, length_2, *args):
        """
        Re-run `run_function` and return gradients of its outputs.

        :param ctx: autograd context; receives `run_function`, both lengths
                    and the tensors needed by `backward`.
        :param run_function: the callable to re-evaluate.
        :param length_1: number of leading entries of `args` that are
                         passed to `run_function`.
        :param length_2: number of entries after those that are treated as
                         parameters.
        :param args: input tensors, then parameters, then the gradients
                     with respect to the outputs of `run_function`.
        :return: the tuple produced by `torch.autograd.grad` with respect
                 to the inputs followed by the parameters.
        """
        ctx.run_function = run_function
        ctx.length_1 = length_1
        ctx.length_2 = length_2

        split = length_1 + length_2
        leaves = [t.detach().requires_grad_(True) for t in args[:length_1]]
        params = list(args[length_1:split])
        cotangents = list(args[split:])
        ctx.save_for_backward(*leaves, *params, *cotangents)

        with torch.enable_grad():
            # run_function may start with an in-place op, which is not
            # allowed on detach()'d tensors, so pass views of the leaves
            # instead of the leaves themselves.
            outputs = run_function(*(t.view_as(t) for t in leaves))
            return torch.autograd.grad(
                outputs,
                leaves + params,
                cotangents,
                allow_unused=True,
            )

    @staticmethod
    @custom_bwd
    def backward(ctx, *all_output_grads):
        """
        Differentiate the gradients returned by `forward`.

        :param ctx: autograd context populated by `forward`.
        :param all_output_grads: gradients with respect to each entry of
                                 the tuple returned by `forward`.
        :return: three `None` entries (for `run_function`, `length_1` and
                 `length_2`) followed by gradients with respect to the
                 inputs, the parameters and the output gradients.
        """
        saved = ctx.saved_tensors
        n_inputs = ctx.length_1
        split = n_inputs + ctx.length_2

        leaves = [t.detach().requires_grad_(True) for t in saved[:n_inputs]]
        params = list(saved[n_inputs:split])
        cotangents = [t.detach().requires_grad_(True) for t in saved[split:]]
        wrt = leaves + params

        with torch.enable_grad():
            # run_function may start with an in-place op, which is not
            # allowed on detach()'d tensors, so pass views of the leaves
            # instead of the leaves themselves.
            outputs = ctx.run_function(*(t.view_as(t) for t in leaves))
            # Keep the graph so these gradients can be differentiated again.
            first_order = torch.autograd.grad(
                outputs,
                wrt,
                cotangents,
                allow_unused=True,
                create_graph=True,
                retain_graph=True,
            )

        second_order = torch.autograd.grad(
            first_order,
            wrt + cotangents,
            all_output_grads,
            allow_unused=True,
        )
        del first_order
        return (None, None, None) + second_order