-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
335 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# CIFAR-10 Evaluation Example | ||
|
||
## Single-process evaluation | ||
|
||
```bash | ||
python cifar10_npu_eval.py | ||
``` | ||
|
||
## Multi-process evaluation with torch.distributed | ||
|
||
```bash | ||
python cifar10_eval_torch_npu_dist.py | ||
``` |
63 changes: 63 additions & 0 deletions
63
examples/cifar10_dist_npu_eval/cifar10_eval_torch_npu_dist.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os | ||
import torch | ||
import torchvision as tv | ||
import tqdm | ||
from torch.utils.data import DataLoader, DistributedSampler | ||
|
||
from mmeval import Accuracy | ||
|
||
|
||
def get_eval_dataloader(rank=0, num_replicas=1):
    """Build the CIFAR-10 test-split loader for one evaluation process.

    Args:
        rank (int): Index of the calling process; forwarded to the
            ``DistributedSampler``. Defaults to 0.
        num_replicas (int): Number of processes sharing the dataset;
            forwarded to the ``DistributedSampler``. Defaults to 1.

    Returns:
        tuple: ``(data_loader, num_samples)``. ``num_samples`` is ``len()``
        of the whole dataset, not of the share handled by this rank.
    """
    to_tensor = tv.transforms.ToTensor()
    # `train=False` selects the test split; `download=True` fetches the
    # archive into the current directory when it is not there yet.
    test_set = tv.datasets.CIFAR10(
        root='./', train=False, download=True, transform=to_tensor)
    shard_sampler = DistributedSampler(
        test_set, rank=rank, num_replicas=num_replicas)
    loader = DataLoader(test_set, sampler=shard_sampler, batch_size=1)
    return loader, len(test_set)
|
||
|
||
def get_model(pretrained_model_fpath=None):
    """Create a 10-class ResNet-18 switched to evaluation mode.

    Args:
        pretrained_model_fpath (str, optional): Path to a checkpoint holding
            a state dict for the model. When ``None`` the model is used as
            constructed. Defaults to None.

    Returns:
        The model returned by ``model.eval()``. No device transfer is done
        here; the caller moves the model to the target device.
    """
    model = tv.models.resnet18(num_classes=10)
    if pretrained_model_fpath is not None:
        # Deserialize onto the CPU first: without `map_location`, tensors are
        # restored to the device they were saved from, which fails when the
        # checkpoint comes from a device type (e.g. CUDA) that this host
        # does not have.
        state_dict = torch.load(pretrained_model_fpath, map_location='cpu')
        model.load_state_dict(state_dict)
    return model.eval()
|
||
|
||
def eval_fn(rank, process_num):
    """Evaluate the model on one rank of an HCCL process group.

    The function joins the process group, binds the process to an NPU,
    feeds its share of the CIFAR-10 test split through the model and prints
    the accuracy computed by the metric.

    Args:
        rank (int): Rank of this process inside the group.
        process_num (int): Total number of processes (the world size).
    """
    # Rendezvous information read by the `env://` init method.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12345')

    torch.distributed.init_process_group(
        backend='hccl',
        init_method='env://',
        rank=rank,
        world_size=process_num)

    # Wrap around so that more processes than devices still get a device.
    device_index = rank % torch.npu.device_count()
    torch.npu.set_device(device_index)

    loader, num_samples = get_eval_dataloader(rank, process_num)
    net = get_model().npu()
    metric = Accuracy(topk=(1, 3), dist_backend='npu_dist')

    show_progress = rank == 0  # only the first rank draws a progress bar
    with torch.no_grad():
        for batch_images, batch_labels in tqdm.tqdm(
                loader, disable=not show_progress):
            npu_images, npu_labels = batch_images.npu(), batch_labels.npu()
            metric.add(predictions=net(npu_images), labels=npu_labels)

    # The full dataset length is handed to `compute` as `size`.
    # NOTE(review): presumably this lets the metric drop samples duplicated
    # by the sampler's padding - confirm against the mmeval documentation.
    print(metric.compute(size=num_samples))
    metric.reset()
|
||
|
||
if __name__ == '__main__':
    # Launch one worker process per rank. `spawn` passes the process index
    # as the first positional argument of `eval_fn`; `args` supplies the
    # remaining one (the world size).
    process_num = 8
    torch.multiprocessing.spawn(
        eval_fn,
        args=(process_num, ),
        nprocs=process_num)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
import torchvision as tv | ||
import tqdm | ||
from torch.utils.data import DataLoader | ||
|
||
from mmeval import Accuracy | ||
|
||
|
||
def get_eval_dataloader():
    """Build a loader over the CIFAR-10 test split, one image per batch.

    Returns:
        DataLoader: Loader over the dataset with ``batch_size=1``.
    """
    to_tensor = tv.transforms.ToTensor()
    # `train=False` selects the test split; `download=True` fetches the
    # archive into the current directory when it is not there yet.
    test_set = tv.datasets.CIFAR10(
        root='./', train=False, download=True, transform=to_tensor)
    return DataLoader(test_set, batch_size=1)
|
||
|
||
def get_model(pretrained_model_fpath=None):
    """Create a 10-class ResNet-18 switched to evaluation mode.

    Args:
        pretrained_model_fpath (str, optional): Path to a checkpoint holding
            a state dict for the model. When ``None`` the model is used as
            constructed. Defaults to None.

    Returns:
        The model returned by ``model.eval()``. No device transfer is done
        here; the caller moves the model to the target device.
    """
    model = tv.models.resnet18(num_classes=10)
    if pretrained_model_fpath is not None:
        # Deserialize onto the CPU first: without `map_location`, tensors are
        # restored to the device they were saved from, which fails when the
        # checkpoint comes from a device type (e.g. CUDA) that this host
        # does not have.
        state_dict = torch.load(pretrained_model_fpath, map_location='cpu')
        model.load_state_dict(state_dict)
    return model.eval()
|
||
|
||
# Single-process evaluation: build the data pipeline, the model (moved to
# the NPU) and the metric, then stream the test split through the model.
eval_dataloader = get_eval_dataloader()
model = get_model().npu()
accuracy = Accuracy(topk=(1, 3))

# Inference only, so gradient tracking is switched off for the whole loop.
with torch.no_grad():
    for images, labels in tqdm.tqdm(eval_dataloader):
        # Move the batch to the NPU before the forward pass.
        images, labels = images.npu(), labels.npu()
        predicted_score = model(images)
        accuracy.add(labels=labels, predictions=predicted_score)

# Report the accumulated result, then clear the metric's state.
print(accuracy.compute())
accuracy.reset()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from typing import TYPE_CHECKING, Any, Tuple, TypeVar, Union | ||
|
||
from mmeval.utils import try_import | ||
from .torch_cpu import TorchCPUDist | ||
|
||
if TYPE_CHECKING: | ||
import torch | ||
import torch_npu | ||
else: | ||
torch = try_import('torch') | ||
torch_npu = try_import('torch_npu') | ||
|
||
Tensor = TypeVar('Tensor', bound='torch.Tensor') | ||
|
||
|
||
class NPUDist(TorchCPUDist):
    """Distributed communication backend for Ascend NPU devices.

    Object (de)serialization is delegated to ``TorchCPUDist``; this class
    only adds the transfers between host memory and the NPU around it.
    """

    def __init__(self) -> None:
        """Check that the NPU stack is usable.

        Raises:
            ImportError: If ``torch_npu`` could not be imported.
            RuntimeError: If ``torch.distributed.is_hccl_available()``
                returns a falsy value.
        """
        super().__init__()
        backend_name = self.__class__.__name__
        if torch_npu is None:
            raise ImportError(
                f'For availability of {backend_name},'
                ' please install ascend pytorch first.')
        hccl_ready = torch.distributed.is_hccl_available()
        if not hccl_ready:
            raise RuntimeError(
                f'For availability of {backend_name},'
                ' make sure torch.distributed.is_hccl_available().')

    def _object_to_tensor(self, obj: Any) -> Tuple[Tensor, Tensor]:
        """Serialize an object into tensors that live on the NPU.

        The conversion itself is done by the parent class; both results are
        then moved to the NPU.

        Args:
            obj (any): Any pickle-able python object.

        Returns:
            tuple: The tensor converted from the given object and the
            tensor holding its size, both on the NPU.
        """
        # Explicit annotations keep mypy satisfied.
        host_tensor: Tensor
        host_size: Tensor
        host_tensor, host_size = super()._object_to_tensor(obj)
        device_tensor = host_tensor.npu()
        device_size = host_size.npu()
        return device_tensor, device_size

    def _tensor_to_object(self, tensor: Tensor,
                          tensor_size: Union[int, Tensor]) -> Any:
        """Rebuild an object from a serialized NPU tensor.

        The tensor is detached and copied to the CPU, then handed to the
        parent class for the actual conversion.

        Args:
            tensor (Tensor): A npu tensor.
            tensor_size (int or Tensor): The size of the given tensor that
                is to be converted back into an object.

        Returns:
            Any: The object converted from the given npu tensor.
        """
        host_tensor = tensor.detach().cpu()
        return super()._tensor_to_object(host_tensor, tensor_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.