NVFuser fixes for 23.11 Modulus release (#197)
* update to newer nvfuser api from 23.10 container

* improve error handling for nvfuser
ktangsali authored Oct 24, 2023
1 parent ed10b4c commit 97f79e5
Showing 9 changed files with 319 additions and 234 deletions.
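For context on the migration described above, here is a minimal, hedged sketch of the new nvfuser Python frontend that the hunks below adopt. The example tensor, dtype choice, and device are illustrative, and it assumes a CUDA-capable environment with the PyTorch 23.10 container or a compatible nvfuser build:

```python
import torch
import nvfuser
from nvfuser import DataType, FusionDefinition

x = torch.randn(4, 8, device="cuda")

# Old frontend (pre-23.10): fusion = Fusion(); with FusionDefinition(fusion) as fd: ...
# New frontend: FusionDefinition is itself the context manager, tensors are declared
# with symbolic shapes, and contiguity is inferred from the concrete size/stride.
with FusionDefinition() as fd:
    t = fd.define_tensor(
        shape=[-1] * x.dim(),
        contiguity=nvfuser.compute_contiguity(x.size(), x.stride()),
        dtype=DataType.Float,
    )
    fd.add_output(fd.ops.sigmoid(t))

out = fd.execute([x])[0]  # execution now happens on the FusionDefinition object
```
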
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Dependencies

- Updated the base container to PyTorch 23.10.

## [0.3.0] - 2023-09-21

### Added
2 changes: 1 addition & 1 deletion Dockerfile
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.07-py3
ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.10-py3
FROM ${BASE_CONTAINER} as builder

ARG TARGETPLATFORM
4 changes: 3 additions & 1 deletion README.md
@@ -62,7 +62,9 @@ Modulus has many optional dependencies that are used in specific components.
When using pip, all dependencies used in Modulus can be installed with
`pip install modulus[all]`. If you are developing Modulus, developer dependencies
can be installed using `pip install modulus[dev]`. Otherwise, additional dependencies
can be installed on a case by case basis.
can be installed on a case by case basis. Detailed information on installing the
optional dependencies can be found in the
[Getting Started Guide](https://docs.nvidia.com/deeplearning/modulus/getting-started/index.html).

### NVCR Container

20 changes: 16 additions & 4 deletions modulus/models/gnn_layers/mesh_graph_mlp.py
@@ -21,8 +21,6 @@
from torch import Tensor
from torch.autograd.function import once_differentiable

from modulus.models.layers.fused_silu import silu_backward_for

from .utils import CuGraphCSC, concat_efeat, sum_efeat

try:
@@ -57,6 +55,11 @@ def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],]:
"""backward pass of the SiLU + Linear function"""

from nvfuser import FusionDefinition

from modulus.models.layers.fused_silu import silu_backward_for

(
need_dgrad,
need_wgrad,
@@ -77,8 +80,17 @@

if need_dgrad:
grad_features = grad_output @ weight
silu_backward = silu_backward_for(features.dtype, features.dim())
grad_silu = silu_backward.execute([features])[0]

with FusionDefinition() as fd:
silu_backward_for(
fd,
features.dtype,
features.dim(),
features.size(),
features.stride(),
)

grad_silu = fd.execute([features])[0]
grad_features = grad_features * grad_silu

return grad_features, grad_weight, grad_bias
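As a sanity check on the dgrad path in this hunk, the manual gradient (grad_output @ weight) scaled by the recomputed SiLU derivative can be compared against autograd. A small sketch in plain PyTorch follows; the forward form linear(silu(x), weight, bias) is an assumption consistent with the backward shown above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8)
b = torch.randn(16)

# Assumed forward, consistent with the backward above: out = linear(silu(x), w, b)
out = F.linear(F.silu(x), w, b)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Manual dgrad mirroring the hunk: (grad_output @ weight) * silu'(x),
# with silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
sig = torch.sigmoid(x)
manual = (grad_out @ w) * (sig * (1 + x * (1 - sig)))
assert torch.allclose(x.grad, manual, atol=1e-5)
```
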
208 changes: 122 additions & 86 deletions modulus/models/layers/fused_silu.py
@@ -13,16 +13,28 @@
# limitations under the License.

import functools
import logging
from typing import Tuple

import torch
from torch.autograd import Function

logger = logging.getLogger(__name__)

try:
from nvfuser._C import DataType, Fusion, FusionDefinition
import nvfuser
from nvfuser import DataType, FusionDefinition
except ImportError:
# accomodating for earlier versions of PyTorch (< 2.0)
# which don't need nvfuser as explicit dependency
from torch._C._nvfuser import DataType, Fusion, FusionDefinition
logger.error(
"An error occured. Either nvfuser is not installed or the version is "
"incompatible. Please retry after installing correct version of nvfuser. "
"The new version of nvfuser should be available in PyTorch container version "
">= 23.10. "
"https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html. "
"If using a source install method, please refer nvFuser repo for installation "
"guidelines https://github.com/NVIDIA/Fuser.",
)
raise

_torch_dtype_to_nvfuser = {
torch.double: DataType.Double,
@@ -38,149 +50,167 @@


@functools.lru_cache(maxsize=None)
def silu_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# z = sigmoid(x)
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# grad_input = y * (1 + x * (1 - y))
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))

grad_input = fd.ops.cast(grad_input, dtype)
grad_input = fd.ops.cast(grad_input, dtype)

fd.add_output(grad_input)

return fusion
fd.add_output(grad_input)


@functools.lru_cache(maxsize=None)
def silu_double_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_double_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU double backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)
# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))

# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
grad_input = fd.ops.add(term1, term2)

grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)

grad_input = fd.ops.cast(grad_input, dtype)

fd.add_output(grad_input)

return fusion
fd.add_output(grad_input)


@functools.lru_cache(maxsize=None)
def silu_triple_backward_for(dtype: torch.dtype, dim: int): # pragma: no cover
def silu_triple_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implementation of SiLU triple backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefinition
nvFuser's FusionDefinition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
Returns
-------
fusion :
An nvfuser fused executor for SiLU backward
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")

fusion = Fusion()

with FusionDefinition(fusion) as fd:
x = fd.define_tensor(dim, dtype)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)

# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)

# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)

grad_input = fd.ops.add(term1, term2)
# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)

grad_input = fd.ops.cast(grad_input, dtype)
grad_input = fd.ops.add(term1, term2)

fd.add_output(grad_input)
grad_input = fd.ops.cast(grad_input, dtype)

return fusion
fd.add_output(grad_input)


class FusedSiLU(Function):
@@ -237,8 +267,10 @@ class FusedSiLU_deriv_1(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_backward = silu_backward_for(x.dtype, x.dim())
return silu_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
@@ -255,8 +287,10 @@ class FusedSiLU_deriv_2(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_double_backward = silu_double_backward_for(x.dtype, x.dim())
return silu_double_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
@@ -273,8 +307,10 @@ class FusedSiLU_deriv_3(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
silu_triple_backward = silu_triple_backward_for(x.dtype, x.dim())
return silu_triple_backward.execute([x])[0]
with FusionDefinition() as fd:
silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
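The closed-form double-backward expression used by silu_double_backward_for can be verified against autograd without nvfuser. A brief sketch in plain PyTorch (float64 is used only to allow a tight tolerance):

```python
import torch

x = torch.randn(32, dtype=torch.float64, requires_grad=True)
y = torch.sigmoid(x)

# Closed form from silu_double_backward_for:
# dy = y * (1 - y); term1 = dy * (1 + x * (1 - y)); term2 = y * ((1 - y) - x * dy)
dy = y * (1 - y)
closed_form = dy * (1 + x * (1 - y)) + y * ((1 - y) - x * dy)

# Autograd reference: d^2/dx^2 [x * sigmoid(x)]
first = torch.autograd.grad((x * torch.sigmoid(x)).sum(), x, create_graph=True)[0]
second = torch.autograd.grad(first.sum(), x)[0]
assert torch.allclose(second, closed_form.detach(), atol=1e-10)
```
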
9 changes: 9 additions & 0 deletions test/models/common/utils.py
@@ -94,3 +94,12 @@ def compare_output(
return False

return True


def is_fusion_available(cls_name: str):
"""Check if certain APIs are available in nvfuser package."""

try:
return hasattr(__import__("nvfuser", fromlist=[cls_name]), cls_name)
except ImportError:
return False
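
One way the new helper might be used is as a pytest skip guard around nvfuser-dependent tests. The sketch below is illustrative only: the import path and the test body are placeholders, and it assumes FusedSiLU.apply computes the SiLU forward on a CUDA tensor:

```python
import pytest

from utils import is_fusion_available  # placeholder import path for test/models/common/utils.py


@pytest.mark.skipif(
    not is_fusion_available("FusionDefinition"),
    reason="nvfuser FusionDefinition is not available in this environment",
)
def test_fused_silu_smoke():
    import torch

    from modulus.models.layers.fused_silu import FusedSiLU

    x = torch.randn(8, device="cuda")
    # Assumes FusedSiLU.apply implements the SiLU forward
    assert torch.allclose(FusedSiLU.apply(x), torch.nn.functional.silu(x), atol=1e-6)
```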
