diff --git a/examples/cifar10_dist_npu_eval/README.md b/examples/cifar10_dist_npu_eval/README.md new file mode 100644 index 00000000..56c7da1e --- /dev/null +++ b/examples/cifar10_dist_npu_eval/README.md @@ -0,0 +1,13 @@ +# CIFAR-10 Evaluation Example + +## Single process evaluation + +```bash +python cifar10_npu_eval.py +``` + +## Multiple processes evaluation with torch.distributed + +```bash +python cifar10_eval_torch_npu_dist.py +``` diff --git a/examples/cifar10_dist_npu_eval/cifar10_eval_torch_npu_dist.py b/examples/cifar10_dist_npu_eval/cifar10_eval_torch_npu_dist.py new file mode 100644 index 00000000..d38233d5 --- /dev/null +++ b/examples/cifar10_dist_npu_eval/cifar10_eval_torch_npu_dist.py @@ -0,0 +1,63 @@ +import os +import torch +import torchvision as tv +import tqdm +from torch.utils.data import DataLoader, DistributedSampler + +from mmeval import Accuracy + + +def get_eval_dataloader(rank=0, num_replicas=1): + dataset = tv.datasets.CIFAR10( + root='./', + train=False, + download=True, + transform=tv.transforms.ToTensor()) + dist_sampler = DistributedSampler( + dataset, num_replicas=num_replicas, rank=rank) + data_loader = DataLoader(dataset, batch_size=1, sampler=dist_sampler) + return data_loader, len(dataset) + + +def get_model(pretrained_model_fpath=None): + model = tv.models.resnet18(num_classes=10) + if pretrained_model_fpath is not None: + model.load_state_dict(torch.load(pretrained_model_fpath)) + return model.eval() + + +def eval_fn(rank, process_num): + master_addr = 'localhost' + master_port = 12345 + + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(master_port) + + torch.distributed.init_process_group( + backend='hccl', + init_method='env://', + world_size=process_num, + rank=rank) + + num_npus = torch.npu.device_count() + torch.npu.set_device(rank % num_npus) + + eval_dataloader, total_num_samples = get_eval_dataloader(rank, process_num) + model = get_model().npu() + accuracy = Accuracy(topk=(1, 3), 
dist_backend='npu_dist') + + with torch.no_grad(): + for images, labels in tqdm.tqdm(eval_dataloader, disable=(rank != 0)): + images = images.npu() + labels = labels.npu() + predicted_score = model(images) + accuracy.add(predictions=predicted_score, labels=labels) + + print(accuracy.compute(size=total_num_samples)) + accuracy.reset() + + +if __name__ == '__main__': + process_num = 8 + torch.multiprocessing.spawn( + eval_fn, nprocs=process_num, args=(process_num, )) diff --git a/examples/cifar10_dist_npu_eval/cifar10_npu_eval.py b/examples/cifar10_dist_npu_eval/cifar10_npu_eval.py new file mode 100644 index 00000000..4d244cdb --- /dev/null +++ b/examples/cifar10_dist_npu_eval/cifar10_npu_eval.py @@ -0,0 +1,37 @@ +import torch +import torchvision as tv +import tqdm +from torch.utils.data import DataLoader + +from mmeval import Accuracy + + +def get_eval_dataloader(): + dataset = tv.datasets.CIFAR10( + root='./', + train=False, + download=True, + transform=tv.transforms.ToTensor()) + return DataLoader(dataset, batch_size=1) + + +def get_model(pretrained_model_fpath=None): + model = tv.models.resnet18(num_classes=10) + if pretrained_model_fpath is not None: + model.load_state_dict(torch.load(pretrained_model_fpath)) + return model.eval() + + +eval_dataloader = get_eval_dataloader() +model = get_model().npu() +accuracy = Accuracy(topk=(1, 3)) + +with torch.no_grad(): + for images, labels in tqdm.tqdm(eval_dataloader): + images = images.npu() + labels = labels.npu() + predicted_score = model(images) + accuracy.add(predictions=predicted_score, labels=labels) + +print(accuracy.compute()) +accuracy.reset() diff --git a/mmeval/core/dist.py b/mmeval/core/dist.py index 2a3e2f6d..1d501aab 100644 --- a/mmeval/core/dist.py +++ b/mmeval/core/dist.py @@ -2,9 +2,9 @@ from typing import List, Optional, no_type_check -from .dist_backends import (BaseDistBackend, MPI4PyDist, NonDist, OneFlowDist, - PaddleDist, TFHorovodDist, TorchCPUDist, - TorchCUDADist) +from .dist_backends import 
(BaseDistBackend, MPI4PyDist, NonDist, NPUDist, + OneFlowDist, PaddleDist, TFHorovodDist, + TorchCPUDist, TorchCUDADist) _DIST_BACKENDS = { 'non_dist': NonDist, @@ -14,6 +14,7 @@ 'torch_cpu': TorchCPUDist, 'torch_cuda': TorchCUDADist, 'paddle_dist': PaddleDist, + 'npu_dist': NPUDist } _DEFAULT_BACKEND = 'non_dist' diff --git a/mmeval/core/dist_backends/__init__.py b/mmeval/core/dist_backends/__init__.py index d35b7130..02d7e878 100644 --- a/mmeval/core/dist_backends/__init__.py +++ b/mmeval/core/dist_backends/__init__.py @@ -3,6 +3,7 @@ from .base_backend import BaseDistBackend, TensorBaseDistBackend from .mpi4py import MPI4PyDist from .non_dist import NonDist +from .npu_dist import NPUDist from .oneflow_dist import OneFlowDist from .paddle_dist import PaddleDist from .tf_horovod import TFHorovodDist @@ -12,5 +13,5 @@ __all__ = [ 'BaseDistBackend', 'TensorBaseDistBackend', 'MPI4PyDist', 'NonDist', 'OneFlowDist', 'TFHorovodDist', 'TorchCPUDist', 'TorchCUDADist', - 'PaddleDist' + 'PaddleDist', 'NPUDist' ] diff --git a/mmeval/core/dist_backends/npu_dist.py b/mmeval/core/dist_backends/npu_dist.py new file mode 100644 index 00000000..0c4086a4 --- /dev/null +++ b/mmeval/core/dist_backends/npu_dist.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+
+from typing import TYPE_CHECKING, Any, Tuple, TypeVar, Union
+
+from mmeval.utils import try_import
+from .torch_cpu import TorchCPUDist
+
+if TYPE_CHECKING:
+    import torch
+    import torch_npu
+else:
+    torch = try_import('torch')
+    torch_npu = try_import('torch_npu')
+
+Tensor = TypeVar('Tensor', bound='torch.Tensor')
+
+
+class NPUDist(TorchCPUDist):
+    """A distributed communication backend for Ascend NPU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        if torch_npu is None:
+            raise ImportError(f'For availability of {self.__class__.__name__},'
+                              ' please install ascend pytorch first.')
+        if not torch.distributed.is_hccl_available():
+            raise RuntimeError(
+                f'For availability of {self.__class__.__name__},'
+                ' make sure torch.distributed.is_hccl_available().')
+
+    def _object_to_tensor(self, obj: Any) -> Tuple[Tensor, Tensor]:
+        """Convert the given object to an npu tensor via `pickle.dumps`.
+
+        Args:
+            obj (Any): Any pickle-able python object.
+
+        Returns:
+            tuple: A tuple of the tensor converted from the given object and
+                the tensor size.
+        """
+        # Add type annotations to make mypy happy
+        obj_tensor: Tensor
+        obj_size_tensor: Tensor
+        obj_tensor, obj_size_tensor = super()._object_to_tensor(obj)
+        return obj_tensor.npu(), obj_size_tensor.npu()
+
+    def _tensor_to_object(self, tensor: Tensor,
+                          tensor_size: Union[int, Tensor]) -> Any:
+        """Convert the given npu tensor to an object via `pickle.loads`.
+
+        Args:
+            tensor (Tensor): An npu tensor.
+            tensor_size (int or Tensor): The tensor size of the given tensor
+                to be converted to an object.
+
+        Returns:
+            Any: The object converted from the given npu tensor. 
+ """ + return super()._tensor_to_object(tensor.detach().cpu(), tensor_size) diff --git a/mmeval/metrics/accuracy.py b/mmeval/metrics/accuracy.py index 48c1149b..8ab0674e 100644 --- a/mmeval/metrics/accuracy.py +++ b/mmeval/metrics/accuracy.py @@ -25,6 +25,8 @@ jax = try_import('jax') flow = try_import('oneflow') +torch = try_import('torch') + @overload @dispatch diff --git a/mmeval/metrics/average_precision.py b/mmeval/metrics/average_precision.py index 62efcc0d..b09918a1 100644 --- a/mmeval/metrics/average_precision.py +++ b/mmeval/metrics/average_precision.py @@ -16,6 +16,8 @@ torch = try_import('torch') flow = try_import('oneflow') +torch_npu = try_import('torch_npu') + NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, np.number]] TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] diff --git a/mmeval/metrics/end_point_error.py b/mmeval/metrics/end_point_error.py index 07fcb77c..928609b6 100644 --- a/mmeval/metrics/end_point_error.py +++ b/mmeval/metrics/end_point_error.py @@ -15,6 +15,8 @@ torch = try_import('torch') flow = try_import('oneflow') +torch_npu = try_import('torch_npu') + class EndPointError(BaseMetric): """EndPointError evaluation metric. diff --git a/mmeval/metrics/f1_score.py b/mmeval/metrics/f1_score.py index 75401591..9718eb3f 100644 --- a/mmeval/metrics/f1_score.py +++ b/mmeval/metrics/f1_score.py @@ -14,6 +14,8 @@ torch = try_import('torch') flow = try_import('oneflow') +torch_npu = try_import('torch_npu') + class F1Score(BaseMetric): """Compute F1 scores. diff --git a/mmeval/metrics/mean_iou.py b/mmeval/metrics/mean_iou.py index be17b987..c431cc17 100644 --- a/mmeval/metrics/mean_iou.py +++ b/mmeval/metrics/mean_iou.py @@ -24,6 +24,8 @@ tf = try_import('tensorflow') flow = try_import('oneflow') +torch_npu = try_import('torch_npu') + class MeanIoU(BaseMetric): """MeanIoU evaluation metric. 
diff --git a/mmeval/metrics/perplexity.py b/mmeval/metrics/perplexity.py index 00d2eadc..1443ffba 100644 --- a/mmeval/metrics/perplexity.py +++ b/mmeval/metrics/perplexity.py @@ -21,6 +21,8 @@ tf = try_import('tensorflow') flow = try_import('oneflow') +torch_npu = try_import('torch_npu') + def softmax(x: np.ndarray) -> np.ndarray: """Compute the softmax function. diff --git a/mmeval/metrics/precision_recall_f1score.py b/mmeval/metrics/precision_recall_f1score.py index 3ca661b7..c08df603 100644 --- a/mmeval/metrics/precision_recall_f1score.py +++ b/mmeval/metrics/precision_recall_f1score.py @@ -21,6 +21,8 @@ flow = try_import('oneflow') of_F = try_import('oneflow.nn.functional') +torch_npu = try_import('torch_npu') + NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.number], Union[np.ndarray, np.number]] TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor'] diff --git a/tests/test_core/test_dist_backends/test_npu_dist_single_node.py b/tests/test_core/test_dist_backends/test_npu_dist_single_node.py new file mode 100644 index 00000000..45ed92de --- /dev/null +++ b/tests/test_core/test_dist_backends/test_npu_dist_single_node.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import os +import pytest + +from mmeval.core.dist_backends.npu_dist import NPUDist + +# check if current process is launch via mpirun +if os.environ.get('OMPI_COMM_WORLD_SIZE', '0') != '0': + pytest.skip(allow_module_level=True) + +torch_npu = pytest.importorskip('torch_npu') + +torch = pytest.importorskip('torch') +torch_dist = pytest.importorskip('torch.distributed') +mp = pytest.importorskip('torch.multiprocessing') + + +def _init_torch_dist(rank, world_size, comm_backend, port): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + + torch_dist.init_process_group( + backend=comm_backend, + init_method='env://', + world_size=world_size, + rank=rank) + + if comm_backend == 'hccl': + num_gpus = torch.npu.device_count() + torch.npu.set_device(rank % num_gpus) + + +def _create_obj_list(world_size): + obj_list = [] + for idx in range(world_size): + obj = dict() + obj['rank'] = idx + obj['world_size'] = world_size + obj['data'] = [i for i in range(idx)] + obj_list.append(obj) + return obj_list + + +def _torch_dist_all_gather_fn(rank, world_size, comm_backend, port): + _init_torch_dist(rank, world_size, comm_backend, port) + dist_comm = NPUDist() + + assert dist_comm.is_initialized + assert dist_comm.rank == rank + assert dist_comm.world_size == world_size + + obj_list = _create_obj_list(world_size) + local_obj = obj_list[rank] + print(f'rank {rank}, local_obj {local_obj}') + + gather_obj_list = dist_comm.all_gather_object(local_obj) + print(f'rank {rank}, gather_obj_list {gather_obj_list}') + + assert gather_obj_list == obj_list + + +def _torch_dist_broadcast_fn(rank, world_size, comm_backend, port): + _init_torch_dist(rank, world_size, comm_backend, port) + dist_comm = NPUDist() + + assert dist_comm.is_initialized + assert dist_comm.rank == rank + assert dist_comm.world_size == world_size + + rank_0_obj = {'rank': 0} + + if rank == 0: + obj = rank_0_obj + else: + obj = None + + print(f'rank {rank}, obj {obj}') + broadcast_obj = 
dist_comm.broadcast_object(obj, src=0)
+    print(f'rank {rank}, broadcast_obj {broadcast_obj}')
+
+    assert broadcast_obj == rank_0_obj
+
+
+@pytest.mark.skipif(
+    not torch_dist.is_hccl_available(),
+    reason='HCCL backend is not available.')
+@pytest.mark.skipif(
+    torch.npu.device_count() < 1,
+    reason='NPU device count must be greater than 0.')
+@pytest.mark.parametrize(
+    argnames=['process_num', 'comm_port'],
+    argvalues=[
+        pytest.param(
+            1,
+            2347,
+            marks=pytest.mark.skipif(
+                torch.npu.device_count() < 1,
+                reason='NPU device count must be greater than 0.')),
+        pytest.param(
+            2,
+            2347,
+            marks=pytest.mark.skipif(
+                torch.npu.device_count() < 2,
+                reason='NPU device count must be at least 2.'))
+    ])
+def test_hccl_all_gather_object(process_num, comm_port):
+    comm_backend = 'hccl'
+    mp.spawn(
+        _torch_dist_all_gather_fn,
+        nprocs=process_num,
+        args=(process_num, comm_backend, comm_port))
+
+
+@pytest.mark.skipif(
+    not torch_dist.is_hccl_available(),
+    reason='HCCL backend is not available.')
+@pytest.mark.parametrize(
+    argnames=['process_num', 'comm_port'],
+    argvalues=[
+        pytest.param(
+            1,
+            2350,
+            marks=pytest.mark.skipif(
+                torch.npu.device_count() < 1,
+                reason='NPU device count must be greater than 0.')),
+        pytest.param(
+            2,
+            2350,
+            marks=pytest.mark.skipif(
+                torch.npu.device_count() < 2,
+                reason='NPU device count must be at least 2.'))
+    ])
+def test_hccl_broadcast_object(process_num, comm_port):
+    comm_backend = 'hccl'
+    mp.spawn(
+        _torch_dist_broadcast_fn,
+        nprocs=process_num,
+        args=(process_num, comm_backend, comm_port))
+
+
+if __name__ == '__main__':
+    pytest.main([__file__, '-v', '--capture=no'])