diff --git a/CHANGELOG.md b/CHANGELOG.md
index b5a2dc5560..bb911c43b4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Dependencies
 
+- Updated the base container to PyTorch 23.10.
+
 ## [0.3.0] - 2023-09-21
 
 ### Added
diff --git a/Dockerfile b/Dockerfile
index 327adab53b..335b1e724f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.07-py3
+ARG BASE_CONTAINER=nvcr.io/nvidia/pytorch:23.10-py3
 FROM ${BASE_CONTAINER} as builder
 
 ARG TARGETPLATFORM
diff --git a/README.md b/README.md
index 398e22a90a..708d965cd4 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,9 @@
 Modulus has many optional dependencies that are used in specific components.
 When using pip, all dependencies used in Modulus can be installed with
 `pip install modulus[all]`. If you are developing Modulus, developer dependencies
 can be installed using `pip install modulus[dev]`. Otherwise, additional dependencies
-can be installed on a case by case basis.
+can be installed on a case by case basis. Detailed information on installing the
+optional dependencies can be found in the
+[Getting Started Guide](https://docs.nvidia.com/deeplearning/modulus/getting-started/index.html).
 
 ### NVCR Container
diff --git a/modulus/models/gnn_layers/mesh_graph_mlp.py b/modulus/models/gnn_layers/mesh_graph_mlp.py
index 8f566a0900..3faaa139bd 100644
--- a/modulus/models/gnn_layers/mesh_graph_mlp.py
+++ b/modulus/models/gnn_layers/mesh_graph_mlp.py
@@ -21,8 +21,6 @@
 from torch import Tensor
 from torch.autograd.function import once_differentiable
 
-from modulus.models.layers.fused_silu import silu_backward_for
-
 from .utils import CuGraphCSC, concat_efeat, sum_efeat
 
 try:
@@ -57,6 +55,11 @@ def backward(
         ctx, grad_output: torch.Tensor
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor],]:
         """backward pass of the SiLU + Linear function"""
+
+        from nvfuser import FusionDefinition
+
+        from modulus.models.layers.fused_silu import silu_backward_for
+
         (
             need_dgrad,
             need_wgrad,
@@ -77,8 +80,17 @@ def backward(
 
         if need_dgrad:
             grad_features = grad_output @ weight
-            silu_backward = silu_backward_for(features.dtype, features.dim())
-            grad_silu = silu_backward.execute([features])[0]
+
+            with FusionDefinition() as fd:
+                silu_backward_for(
+                    fd,
+                    features.dtype,
+                    features.dim(),
+                    features.size(),
+                    features.stride(),
+                )
+
+            grad_silu = fd.execute([features])[0]
             grad_features = grad_features * grad_silu
 
         return grad_features, grad_weight, grad_bias
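Note (not part of the patch): the backward pass above shows the nvfuser frontend pattern this change migrates to — the helper now only records operations on a FusionDefinition passed in by the caller, and the caller executes it. A minimal sketch of that pattern, assuming the nvfuser package shipped with the >= 23.10 PyTorch containers (function and tensor names here are illustrative only):

    # Illustrative sketch of the new nvfuser frontend call pattern; not part of this patch.
    import torch
    import nvfuser
    from nvfuser import DataType, FusionDefinition

    def sigmoid_definition(fd: FusionDefinition, dtype, dim, size, stride) -> None:
        # The helper only records ops on `fd`; it no longer returns an executor.
        x = fd.define_tensor(
            shape=[-1] * dim,
            contiguity=nvfuser.compute_contiguity(size, stride),
            dtype=dtype,
        )
        fd.add_output(fd.ops.sigmoid(x))

    features = torch.randn(8, 16, device="cuda")
    with FusionDefinition() as fd:
        sigmoid_definition(
            fd, DataType.Float, features.dim(), features.size(), features.stride()
        )
    out = fd.execute([features])[0]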
diff --git a/modulus/models/layers/fused_silu.py b/modulus/models/layers/fused_silu.py
index 13eaac33cd..52e7a110b4 100644
--- a/modulus/models/layers/fused_silu.py
+++ b/modulus/models/layers/fused_silu.py
@@ -13,16 +13,28 @@
 import functools
+import logging
+from typing import Tuple
 
 import torch
 from torch.autograd import Function
 
+logger = logging.getLogger(__name__)
+
 try:
-    from nvfuser._C import DataType, Fusion, FusionDefinition
+    import nvfuser
+    from nvfuser import DataType, FusionDefinition
 except ImportError:
-    # accomodating for earlier versions of PyTorch (< 2.0)
-    # which don't need nvfuser as explicit dependency
-    from torch._C._nvfuser import DataType, Fusion, FusionDefinition
+    logger.error(
+        "An error occurred. Either nvfuser is not installed or the version is "
+        "incompatible. Please retry after installing the correct version of nvfuser. "
+        "The new version of nvfuser should be available in PyTorch container version "
+        ">= 23.10: "
+        "https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html. "
+        "If using a source install method, please refer to the nvFuser repo for "
+        "installation guidelines: https://github.com/NVIDIA/Fuser.",
+    )
+    raise
 
 _torch_dtype_to_nvfuser = {
     torch.double: DataType.Double,
@@ -38,149 +50,167 @@
 @functools.lru_cache(maxsize=None)
-def silu_backward_for(dtype: torch.dtype, dim: int):  # pragma: no cover
+def silu_backward_for(
+    fd: FusionDefinition,
+    dtype: torch.dtype,
+    dim: int,
+    size: torch.Size,
+    stride: Tuple[int, ...],
+):  # pragma: no cover
     """
     nvfuser frontend implmentation of SiLU backward as a fused kernel and
     with activations recomputation
 
     Parameters
     ----------
+    fd : FusionDefinition
+        nvFuser's FusionDefinition class
     dtype : torch.dtype
         Data type to use for the implementation
     dim : int
         Dimension of the input tensor
-
-    Returns
-    -------
-    fusion :
-        An nvfuser fused executor for SiLU backward
+    size : torch.Size
+        Size of the input tensor
+    stride : Tuple[int, ...]
+        Stride of the input tensor
     """
     try:
         dtype = _torch_dtype_to_nvfuser[dtype]
     except KeyError:
         raise TypeError("Unsupported dtype")
 
-    fusion = Fusion()
-
-    with FusionDefinition(fusion) as fd:
-        x = fd.define_tensor(dim, dtype)
-        one = fd.define_constant(1.0)
+    x = fd.define_tensor(
+        shape=[-1] * dim,
+        contiguity=nvfuser.compute_contiguity(size, stride),
+        dtype=dtype,
+    )
+    one = fd.define_constant(1.0)
 
-        # y = sigmoid(x)
-        y = fd.ops.sigmoid(x)
-        # z = sigmoid(x)
-        grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
+    # y = sigmoid(x)
+    y = fd.ops.sigmoid(x)
+    # grad = y * (1 + x * (1 - y))
+    grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
 
-        grad_input = fd.ops.cast(grad_input, dtype)
+    grad_input = fd.ops.cast(grad_input, dtype)
 
-        fd.add_output(grad_input)
-
-    return fusion
+    fd.add_output(grad_input)
 
 
 @functools.lru_cache(maxsize=None)
-def silu_double_backward_for(dtype: torch.dtype, dim: int):  # pragma: no cover
+def silu_double_backward_for(
+    fd: FusionDefinition,
+    dtype: torch.dtype,
+    dim: int,
+    size: torch.Size,
+    stride: Tuple[int, ...],
+):  # pragma: no cover
     """
     nvfuser frontend implmentation of SiLU double backward as a fused kernel and
     with activations recomputation
 
     Parameters
     ----------
+    fd : FusionDefinition
+        nvFuser's FusionDefinition class
     dtype : torch.dtype
         Data type to use for the implementation
     dim : int
         Dimension of the input tensor
-
-    Returns
-    -------
-    fusion :
-        An nvfuser fused executor for SiLU backward
+    size : torch.Size
+        Size of the input tensor
+    stride : Tuple[int, ...]
+        Stride of the input tensor
     """
     try:
         dtype = _torch_dtype_to_nvfuser[dtype]
     except KeyError:
         raise TypeError("Unsupported dtype")
 
-    fusion = Fusion()
+    x = fd.define_tensor(
+        shape=[-1] * dim,
+        contiguity=nvfuser.compute_contiguity(size, stride),
+        dtype=dtype,
+    )
+    one = fd.define_constant(1.0)
 
-    with FusionDefinition(fusion) as fd:
-        x = fd.define_tensor(dim, dtype)
-        one = fd.define_constant(1.0)
+    # y = sigmoid(x)
+    y = fd.ops.sigmoid(x)
+    # dy = y * (1 - y)
+    dy = fd.ops.mul(y, fd.ops.sub(one, y))
+    # z = 1 + x * (1 - y)
+    z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
+    # term1 = dy * z
+    term1 = fd.ops.mul(dy, z)
 
-        # y = sigmoid(x)
-        y = fd.ops.sigmoid(x)
-        # dy = y * (1 - y)
-        dy = fd.ops.mul(y, fd.ops.sub(one, y))
-        # z = 1 + x * (1 - y)
-        z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
-        # term1 = dy * z
-        term1 = fd.ops.mul(dy, z)
+    # term2 = y * ((1 - y) - x * dy)
+    term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
 
-        # term2 = y * ((1 - y) - x * dy)
-        term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
+    grad_input = fd.ops.add(term1, term2)
 
-        grad_input = fd.ops.add(term1, term2)
+    grad_input = fd.ops.cast(grad_input, dtype)
 
-        grad_input = fd.ops.cast(grad_input, dtype)
-
-        fd.add_output(grad_input)
-
-    return fusion
+    fd.add_output(grad_input)
 
 
 @functools.lru_cache(maxsize=None)
-def silu_triple_backward_for(dtype: torch.dtype, dim: int):  # pragma: no cover
+def silu_triple_backward_for(
+    fd: FusionDefinition,
+    dtype: torch.dtype,
+    dim: int,
+    size: torch.Size,
+    stride: Tuple[int, ...],
+):  # pragma: no cover
     """
     nvfuser frontend implmentation of SiLU triple backward as a fused kernel and
     with activations recomputation
 
     Parameters
     ----------
+    fd : FusionDefinition
+        nvFuser's FusionDefinition class
     dtype : torch.dtype
         Data type to use for the implementation
     dim : int
         Dimension of the input tensor
-
-    Returns
-    -------
-    fusion :
-        An nvfuser fused executor for SiLU backward
+    size : torch.Size
+        Size of the input tensor
+    stride : Tuple[int, ...]
+        Stride of the input tensor
     """
     try:
         dtype = _torch_dtype_to_nvfuser[dtype]
     except KeyError:
         raise TypeError("Unsupported dtype")
 
-    fusion = Fusion()
-
-    with FusionDefinition(fusion) as fd:
-        x = fd.define_tensor(dim, dtype)
-        one = fd.define_constant(1.0)
-        two = fd.define_constant(2.0)
-
-        # y = sigmoid(x)
-        y = fd.ops.sigmoid(x)
-        # dy = y * (1 - y)
-        dy = fd.ops.mul(y, fd.ops.sub(one, y))
-        # ddy = (1 - 2y) * dy
-        ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
-        # term1 = ddy * (2 + x - 2xy)
-        term1 = fd.ops.mul(
-            ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
-        )
+    x = fd.define_tensor(
+        shape=[-1] * dim,
+        contiguity=nvfuser.compute_contiguity(size, stride),
+        dtype=dtype,
+    )
+    one = fd.define_constant(1.0)
+    two = fd.define_constant(2.0)
 
-        # term2 = dy * (1 - 2 (y + x * dy))
-        term2 = fd.ops.mul(
-            dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
-        )
+    # y = sigmoid(x)
+    y = fd.ops.sigmoid(x)
+    # dy = y * (1 - y)
+    dy = fd.ops.mul(y, fd.ops.sub(one, y))
+    # ddy = (1 - 2y) * dy
+    ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
+    # term1 = ddy * (2 + x - 2xy)
+    term1 = fd.ops.mul(
+        ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
+    )
 
-        grad_input = fd.ops.add(term1, term2)
+    # term2 = dy * (1 - 2 (y + x * dy))
+    term2 = fd.ops.mul(
+        dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
+    )
 
-        grad_input = fd.ops.cast(grad_input, dtype)
+    grad_input = fd.ops.add(term1, term2)
 
-        fd.add_output(grad_input)
+    grad_input = fd.ops.cast(grad_input, dtype)
 
-    return fusion
+    fd.add_output(grad_input)
 
 
 class FusedSiLU(Function):
@@ -237,8 +267,10 @@ class FusedSiLU_deriv_1(Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        silu_backward = silu_backward_for(x.dtype, x.dim())
-        return silu_backward.execute([x])[0]
+        with FusionDefinition() as fd:
+            silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
+        out = fd.execute([x])[0]
+        return out
 
     @staticmethod
     def backward(ctx, grad_output):  # pragma: no cover
@@ -255,8 +287,10 @@ class FusedSiLU_deriv_2(Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        silu_double_backward = silu_double_backward_for(x.dtype, x.dim())
-        return silu_double_backward.execute([x])[0]
+        with FusionDefinition() as fd:
+            silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
+        out = fd.execute([x])[0]
+        return out
 
     @staticmethod
     def backward(ctx, grad_output):  # pragma: no cover
@@ -273,8 +307,10 @@ class FusedSiLU_deriv_3(Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        silu_triple_backward = silu_triple_backward_for(x.dtype, x.dim())
-        return silu_triple_backward.execute([x])[0]
+        with FusionDefinition() as fd:
+            silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
+        out = fd.execute([x])[0]
+        return out
 
     @staticmethod
     def backward(ctx, grad_output):  # pragma: no cover
diff --git a/test/models/common/utils.py b/test/models/common/utils.py
index 812e528413..284e50a5ef 100644
--- a/test/models/common/utils.py
+++ b/test/models/common/utils.py
@@ -94,3 +94,12 @@ def compare_output(
             return False
 
     return True
+
+
+def is_fusion_available(cls_name: str):
+    """Check if certain APIs are available in nvfuser package."""
+
+    try:
+        return hasattr(__import__("nvfuser", fromlist=[cls_name]), cls_name)
+    except ImportError:
+        return False
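Note (not part of the patch): the new is_fusion_available helper above is what the GraphCast and activation tests below use to skip fusion-dependent cases. A hypothetical usage sketch (the import path assumes the same sys.path setup those tests perform):

    # Hypothetical test using the helper; mirrors the skip pattern adopted below.
    import pytest
    from common.utils import is_fusion_available

    @pytest.mark.skipif(
        not is_fusion_available("FusionDefinition"),
        reason="nvfuser module is not available or has an incorrect version",
    )
    def test_needs_nvfuser():
        # Safe to import here: the skipif guard above ensures nvfuser is usable.
        from nvfuser import FusionDefinition

        assert FusionDefinition is not None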
diff --git a/test/models/graphcast/test_concat_trick.py b/test/models/graphcast/test_concat_trick.py
index dbe2f556e3..3904815c1d 100644
--- a/test/models/graphcast/test_concat_trick.py
+++ b/test/models/graphcast/test_concat_trick.py
@@ -11,21 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import sys
 
-import numpy as np
-import torch
-from pytest_utils import import_or_fail
-from utils import fix_random_seeds, get_icosphere_path
+script_path = os.path.abspath(__file__)
+sys.path.append(os.path.join(os.path.dirname(script_path), ".."))
+
+import common  # noqa: E402
+import numpy as np  # noqa: E402
+import pytest  # noqa: E402
+import torch  # noqa: E402
+from pytest_utils import import_or_fail  # noqa: E402
+from utils import fix_random_seeds, get_icosphere_path  # noqa: E402
 
 icosphere_path = get_icosphere_path()
 
 
 @import_or_fail("dgl")
-def test_concat_trick(pytestconfig, num_channels=2, res_h=11, res_w=20):
+@pytest.mark.parametrize("recomp_act", [False, True])
+def test_concat_trick(pytestconfig, recomp_act, num_channels=2, res_h=11, res_w=20):
     """Test concat trick"""
 
     from modulus.models.graphcast.graph_cast_net import GraphCastNet
 
+    if recomp_act and not common.utils.is_fusion_available("FusionDefinition"):
+        pytest.skip("nvfuser module is not available or has an incorrect version")
+
     # Fix random seeds
     fix_random_seeds()
 
@@ -34,70 +45,69 @@
     x = torch.rand(1, num_channels, res_h, res_w, device=device)
     x_ct = x.clone().detach()
 
-    for recomp_act in [False, True]:
-        # Fix random seeds
-        torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
-        np.random.seed(0)
-
-        # Instantiate the model
-        model = GraphCastNet(
-            meshgraph_path=icosphere_path,
-            static_dataset_path=None,
-            input_res=(res_h, res_w),
-            input_dim_grid_nodes=num_channels,
-            input_dim_mesh_nodes=3,
-            input_dim_edges=4,
-            output_dim_grid_nodes=num_channels,
-            processor_layers=3,
-            hidden_dim=4,
-            do_concat_trick=False,
-            recompute_activation=False,
-        ).to("cuda")
-
-        # Fix random seeds again
-        fix_random_seeds()
-
-        # Instantiate the model with concat trick enabled
-        model_ct = GraphCastNet(
-            meshgraph_path=icosphere_path,
-            static_dataset_path=None,
-            input_res=(res_h, res_w),
-            input_dim_grid_nodes=num_channels,
-            input_dim_mesh_nodes=3,
-            input_dim_edges=4,
-            output_dim_grid_nodes=num_channels,
-            processor_layers=3,
-            hidden_dim=4,
-            do_concat_trick=True,
-            recompute_activation=recomp_act,
-        ).to(device)
-
-        # Forward pass without checkpointing
-        x.requires_grad_()
-        y_pred = model(x)
-        loss = y_pred.sum()
-        loss.backward()
-        x_grad = x.grad
-
-        x_ct.requires_grad_()
-        y_pred_ct = model_ct(x_ct)
-        loss_ct = y_pred_ct.sum()
-        loss_ct.backward()
-        x_grad_ct = x_ct.grad
-
-        # Check that the results are the same
-        # tolerances quite large on GPU
-        assert torch.allclose(
-            y_pred_ct,
-            y_pred,
-            atol=5.0e-3,
-        ), "Concat trick failed, outputs do not match!"
-        assert torch.allclose(
-            x_grad_ct,
-            x_grad,
-            atol=1.0e-2,
-        ), "Concat trick failed, gradients do not match!"
+    # Fix random seeds
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+    np.random.seed(0)
+
+    # Instantiate the model
+    model = GraphCastNet(
+        meshgraph_path=icosphere_path,
+        static_dataset_path=None,
+        input_res=(res_h, res_w),
+        input_dim_grid_nodes=num_channels,
+        input_dim_mesh_nodes=3,
+        input_dim_edges=4,
+        output_dim_grid_nodes=num_channels,
+        processor_layers=3,
+        hidden_dim=4,
+        do_concat_trick=False,
+        recompute_activation=False,
+    ).to("cuda")
+
+    # Fix random seeds again
+    fix_random_seeds()
+
+    # Instantiate the model with concat trick enabled
+    model_ct = GraphCastNet(
+        meshgraph_path=icosphere_path,
+        static_dataset_path=None,
+        input_res=(res_h, res_w),
+        input_dim_grid_nodes=num_channels,
+        input_dim_mesh_nodes=3,
+        input_dim_edges=4,
+        output_dim_grid_nodes=num_channels,
+        processor_layers=3,
+        hidden_dim=4,
+        do_concat_trick=True,
+        recompute_activation=recomp_act,
+    ).to(device)
+
+    # Forward pass without checkpointing
+    x.requires_grad_()
+    y_pred = model(x)
+    loss = y_pred.sum()
+    loss.backward()
+    x_grad = x.grad
+    x_ct.requires_grad_()
+    y_pred_ct = model_ct(x_ct)
+    loss_ct = y_pred_ct.sum()
+    loss_ct.backward()
+    x_grad_ct = x_ct.grad
+
+    # Check that the results are the same
+    # tolerances quite large on GPU
+    assert torch.allclose(
+        y_pred_ct,
+        y_pred,
+        atol=5.0e-3,
+    ), "Concat trick failed, outputs do not match!"
+
+    assert torch.allclose(
+        x_grad_ct,
+        x_grad,
+        atol=1.0e-2,
+    ), "Concat trick failed, gradients do not match!"
 
 
 if __name__ == "__main__":
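Note (not part of the patch): the recompute_activation=True cases exercised by these tests run through the fused SiLU backward defined earlier, which recomputes y = sigmoid(x) and applies d/dx[x * sigmoid(x)] = y * (1 + x * (1 - y)). That identity can be sanity-checked against autograd with a few lines of plain PyTorch:

    # Standalone check of the derivative recomputed by silu_backward_for.
    import torch

    x = torch.randn(64, dtype=torch.double, requires_grad=True)
    y = torch.sigmoid(x)
    manual = y * (1 + x * (1 - y))
    (auto,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)
    assert torch.allclose(manual, auto)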
diff --git a/test/models/graphcast/test_cugraphops.py b/test/models/graphcast/test_cugraphops.py
index 0da1a3149c..c66822bd8a 100644
--- a/test/models/graphcast/test_cugraphops.py
+++ b/test/models/graphcast/test_cugraphops.py
@@ -11,20 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import sys
 
-import numpy as np
-import torch
-from pytest_utils import import_or_fail
-from utils import fix_random_seeds, get_icosphere_path
+script_path = os.path.abspath(__file__)
+sys.path.append(os.path.join(os.path.dirname(script_path), ".."))
+
+import common  # noqa: E402
+import numpy as np  # noqa: E402
+import pytest  # noqa: E402
+import torch  # noqa: E402
+from pytest_utils import import_or_fail  # noqa: E402
+from utils import fix_random_seeds, get_icosphere_path  # noqa: E402
 
 
 @import_or_fail("dgl")
-def test_cugraphops(pytestconfig, num_channels=2, res_h=21, res_w=10):
+@pytest.mark.parametrize("recomp_act", [False, True])
+@pytest.mark.parametrize("concat_trick", [False, True])
+def test_cugraphops(
+    pytestconfig, recomp_act, concat_trick, num_channels=2, res_h=21, res_w=10
+):
     """Test cugraphops"""
 
     icosphere_path = get_icosphere_path()
 
     from modulus.models.graphcast.graph_cast_net import GraphCastNet
 
+    if recomp_act and not common.utils.is_fusion_available("FusionDefinition"):
+        pytest.skip("nvfuser module is not available or has an incorrect version")
+
     # Fix random seeds
     fix_random_seeds()
 
@@ -32,71 +46,65 @@
     x = torch.randn(1, num_channels, res_h, res_w, device="cuda")
     x_dgl = x.clone().detach()
 
-    for concat_trick in [False, True]:
-        for recomp_act in [False, True]:
-            # Fix random seeds
-            torch.manual_seed(0)
-            torch.cuda.manual_seed(0)
-            np.random.seed(0)
-
-            model = GraphCastNet(
-                meshgraph_path=icosphere_path,
-                static_dataset_path=None,
-                input_res=(res_h, res_w),
-                input_dim_grid_nodes=num_channels,
-                input_dim_mesh_nodes=3,
-                input_dim_edges=4,
-                output_dim_grid_nodes=num_channels,
-                processor_layers=3,
-                hidden_dim=4,
-                do_concat_trick=concat_trick,
-                use_cugraphops_decoder=True,
-                use_cugraphops_encoder=True,
-                use_cugraphops_processor=True,
-                recompute_activation=recomp_act,
-            ).to("cuda")
-
-            # Fix random seeds again
-            fix_random_seeds()
+    # Fix random seeds
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+    np.random.seed(0)
 
-            model_dgl = GraphCastNet(
-                meshgraph_path=icosphere_path,
-                static_dataset_path=None,
-                input_res=(res_h, res_w),
-                input_dim_grid_nodes=num_channels,
-                input_dim_mesh_nodes=3,
-                input_dim_edges=4,
-                output_dim_grid_nodes=num_channels,
-                processor_layers=3,
-                hidden_dim=4,
-                do_concat_trick=concat_trick,
-                use_cugraphops_decoder=False,
-                use_cugraphops_encoder=False,
-                use_cugraphops_processor=False,
-                recompute_activation=False,
-            ).to("cuda")
+    model = GraphCastNet(
+        meshgraph_path=icosphere_path,
+        static_dataset_path=None,
+        input_res=(res_h, res_w),
+        input_dim_grid_nodes=num_channels,
+        input_dim_mesh_nodes=3,
+        input_dim_edges=4,
+        output_dim_grid_nodes=num_channels,
+        processor_layers=3,
+        hidden_dim=4,
+        do_concat_trick=concat_trick,
+        use_cugraphops_decoder=True,
+        use_cugraphops_encoder=True,
+        use_cugraphops_processor=True,
+        recompute_activation=recomp_act,
+    ).to("cuda")
 
-            # Forward pass without checkpointing
-            x.requires_grad_()
-            y_pred = model(x)
-            loss = y_pred.sum()
-            loss.backward()
-            x_grad = x.grad
+    # Fix random seeds again
+    fix_random_seeds()
 
-            x_dgl.requires_grad_()
-            y_pred_dgl = model_dgl(x_dgl)
-            loss_dgl = y_pred_dgl.sum()
-            loss_dgl.backward()
-            x_grad_dgl = x_dgl.grad
+    model_dgl = GraphCastNet(
+        meshgraph_path=icosphere_path,
+        static_dataset_path=None,
+        input_res=(res_h, res_w),
+        input_dim_grid_nodes=num_channels,
+        input_dim_mesh_nodes=3,
+        input_dim_edges=4,
+        output_dim_grid_nodes=num_channels,
+        processor_layers=3,
+        hidden_dim=4,
+        do_concat_trick=concat_trick,
+        use_cugraphops_decoder=False,
+        use_cugraphops_encoder=False,
+        use_cugraphops_processor=False,
+        recompute_activation=False,
+    ).to("cuda")
 
-            # Check that the results are the same
-            assert torch.allclose(
-                y_pred_dgl, y_pred, atol=1.0e-6
-            ), "testing DGL against cugraph-ops: outputs do not match!"
-            assert torch.allclose(
-                x_grad_dgl, x_grad, atol=1.0e-4, rtol=1.0e-3
-            ), "testing DGL against cugraph-ops: gradients do not match!"
+    # Forward pass without checkpointing
+    x.requires_grad_()
+    y_pred = model(x)
+    loss = y_pred.sum()
+    loss.backward()
+    x_grad = x.grad
+    x_dgl.requires_grad_()
+    y_pred_dgl = model_dgl(x_dgl)
+    loss_dgl = y_pred_dgl.sum()
+    loss_dgl.backward()
+    x_grad_dgl = x_dgl.grad
+    # Check that the results are the same
+    assert torch.allclose(
+        y_pred_dgl, y_pred, atol=1.0e-6
+    ), "testing DGL against cugraph-ops: outputs do not match!"
 
-if __name__ == "__main__":
-    test_cugraphops()
+    assert torch.allclose(
+        x_grad_dgl, x_grad, atol=1.0e-4, rtol=1.0e-3
+    ), "testing DGL against cugraph-ops: gradients do not match!"
diff --git a/test/models/test_layers_activations.py b/test/models/test_layers_activations.py
index b93db96840..ae27802ef6 100644
--- a/test/models/test_layers_activations.py
+++ b/test/models/test_layers_activations.py
@@ -18,12 +18,6 @@
 import torch
 
 from modulus.models.layers.activations import Identity, SquarePlus, Stan
-from modulus.models.layers.fused_silu import (
-    FusedSiLU,
-    FusedSiLU_deriv_1,
-    FusedSiLU_deriv_2,
-    FusedSiLU_deriv_3,
-)
 
 from . import common
 
@@ -77,9 +71,21 @@ def test_activation_squareplus(device):
     assert common.compare_output(torch.ones_like(invar), outvar)
 
 
+@pytest.mark.skipif(
+    not common.utils.is_fusion_available("FusionDefinition"),
+    reason="nvfuser module is not available or has an incorrect version",
+)
 @pytest.mark.parametrize("device", ["cuda:0"])
 def test_activation_fused_silu(device):
     """Test fused SiLU implementation"""
+
+    from modulus.models.layers.fused_silu import (
+        FusedSiLU,
+        FusedSiLU_deriv_1,
+        FusedSiLU_deriv_2,
+        FusedSiLU_deriv_3,
+    )
+
     input = torch.randn(20, 20, dtype=torch.double, requires_grad=True, device=device)
     assert torch.autograd.gradcheck(
         FusedSiLU.apply, input, eps=1e-6, atol=1e-4