NVFuser fixes for 23.11 Modulus release #197

Merged: 12 commits, Oct 24, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Dependencies

- Updated the base container to PyTorch 23.10.

## [0.3.0] - 2023-09-21

### Added
2 changes: 1 addition & 1 deletion Dockerfile
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.07-py3
ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.10-py3
FROM ${BASE_CONTAINER} as builder

ARG TARGETPLATFORM
20 changes: 16 additions & 4 deletions modulus/models/gnn_layers/mesh_graph_mlp.py
@@ -21,8 +21,6 @@
from torch import Tensor
from torch.autograd.function import once_differentiable

from modulus.models.layers.fused_silu import silu_backward_for

from .utils import CuGraphCSC, concat_efeat, sum_efeat

try:
@@ -57,6 +55,11 @@ def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],]:
"""backward pass of the SiLU + Linear function"""

from nvfuser import FusionDefinition

from modulus.models.layers.fused_silu import silu_backward_for

(
need_dgrad,
need_wgrad,
@@ -77,8 +80,17 @@

if need_dgrad:
grad_features = grad_output @ weight
silu_backward = silu_backward_for(features.dtype, features.dim())
grad_silu = silu_backward.execute([features])[0]

with FusionDefinition() as fd:
silu_backward_for(
fd,
features.dtype,
features.dim(),
features.size(),
features.stride(),
)

grad_silu = fd.execute([features])[0]
grad_features = grad_features * grad_silu

return grad_features, grad_weight, grad_bias
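For context, the backward now builds the SiLU-derivative fusion through the new nvFuser frontend at call time instead of caching a prebuilt `Fusion` object. The following is a minimal standalone sketch of the same pattern, checked against the eager formula; it assumes a CUDA device, an nvfuser build matching the 23.10 container, an illustrative tensor shape, and that `float32` is covered by the dtype mapping in `fused_silu.py`:

```python
import torch
from nvfuser import FusionDefinition

from modulus.models.layers.fused_silu import silu_backward_for

# Illustrative input; any CUDA tensor with a supported dtype works.
features = torch.randn(128, 64, device="cuda", dtype=torch.float32)

# Build the fusion for this dtype/rank/shape/stride, then run it on the input.
with FusionDefinition() as fd:
    silu_backward_for(
        fd, features.dtype, features.dim(), features.size(), features.stride()
    )
grad_silu = fd.execute([features])[0]

# Eager reference: silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
sig = torch.sigmoid(features)
torch.testing.assert_close(
    grad_silu, sig * (1 + features * (1 - sig)), rtol=1e-4, atol=1e-4
)
```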
209 changes: 122 additions & 87 deletions modulus/models/layers/fused_silu.py
@@ -13,16 +13,27 @@
# limitations under the License.

import functools
import logging
from typing import Tuple

import torch
from torch.autograd import Function

logger = logging.getLogger(__name__)

try:
from nvfuser._C import DataType, Fusion, FusionDefinition
except ImportError:
# accommodating for earlier versions of PyTorch (< 2.0)
# which don't need nvfuser as explicit dependency
from torch._C._nvfuser import DataType, Fusion, FusionDefinition
import nvfuser
from nvfuser import DataType, FusionDefinition
except ImportError as e:
logger.error(
"Either nvfuser is not installed or the installed version is incompatible. "
"Please retry after installing the correct version of nvfuser. "
"The required nvfuser frontend ships in PyTorch container versions >= 23.10: "
"https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html "
"Details: %s",
e,
)

_torch_dtype_to_nvfuser = {
torch.double: DataType.Double,
@@ -38,149 +49,167 @@


@functools.lru_cache(maxsize=None)
def silu_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU backward as a fused kernel with
activation recomputation

Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor

Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# silu'(x) = y * (1 + x * (1 - y))
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# silu'(x) = y * (1 + x * (1 - y))
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))

grad_input = fd.ops.cast(grad_input, dtype)
grad_input = fd.ops.cast(grad_input, dtype)

fd.add_output(grad_input)

return fusion
fd.add_output(grad_input)
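For reference, the identity this fusion implements follows from differentiating $\mathrm{silu}(x) = x\,\sigma(x)$:

$$
\frac{d}{dx}\,\mathrm{silu}(x) = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr) = \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr),
$$

which is exactly `y * (1 + x * (1 - y))` with `y = sigmoid(x)` in the fusion above.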


@functools.lru_cache(maxsize=None)
def silu_double_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_double_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU double backward as a fused kernel with
activation recomputation

Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor

Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)
# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))

# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
grad_input = fd.ops.add(term1, term2)

grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)

grad_input = fd.ops.cast(grad_input, dtype)

fd.add_output(grad_input)

return fusion
fd.add_output(grad_input)
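Similarly, with $y = \sigma(x)$, $dy = y(1-y)$ and $z = 1 + x(1-y)$, differentiating the first derivative once more gives

$$
\mathrm{silu}''(x) = dy\,z + y\bigl((1-y) - x\,dy\bigr) = dy\,(2 + x - 2xy),
$$

which is the `term1 + term2` sum assembled above.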


@functools.lru_cache(maxsize=None)
def silu_triple_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_triple_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU triple backward as a fused kernel with
activation recomputation

Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor

Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)

# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)

grad_input = fd.ops.add(term1, term2)
# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)

grad_input = fd.ops.cast(grad_input, dtype)
grad_input = fd.ops.add(term1, term2)

fd.add_output(grad_input)
grad_input = fd.ops.cast(grad_input, dtype)

return fusion
fd.add_output(grad_input)
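Differentiating $\mathrm{silu}''(x) = dy\,(2 + x - 2xy)$ once more, with $ddy = (1-2y)\,dy$, gives

$$
\mathrm{silu}'''(x) = ddy\,(2 + x - 2xy) + dy\,\bigl(1 - 2\,(y + x\,dy)\bigr),
$$

matching `term1 + term2` in this fusion.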


class FusedSiLU(Function):
Expand Down Expand Up @@ -237,8 +266,10 @@ class FusedSiLU_deriv_1(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_backward = silu_backward_for(x.dtype, x.dim())
return silu_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
@@ -255,8 +286,10 @@ class FusedSiLU_deriv_2(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_double_backward = silu_double_backward_for(x.dtype, x.dim())
return silu_double_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
@@ -273,8 +306,10 @@ class FusedSiLU_deriv_3(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_triple_backward = silu_triple_backward_for(x.dtype, x.dim())
return silu_triple_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
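To exercise the refactored autograd Functions end to end, a sanity check along these lines can be run against plain PyTorch autograd. This is a sketch, not part of the PR's test suite; it assumes a CUDA device, a compatible nvfuser build, and an illustrative tensor shape and tolerance:

```python
import torch
import torch.nn.functional as F

from modulus.models.layers.fused_silu import FusedSiLU_deriv_1

x = torch.randn(64, 32, device="cuda", dtype=torch.float32, requires_grad=True)

# Autograd reference for the first derivative of SiLU.
(grad_ref,) = torch.autograd.grad(F.silu(x).sum(), x)

# Fused nvFuser path introduced by this PR.
grad_fused = FusedSiLU_deriv_1.apply(x.detach())

torch.testing.assert_close(grad_fused, grad_ref, rtol=1e-4, atol=1e-4)
```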