From e3431d29c42e4e200a6f6db993f61380ea86dbaa Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 27 Dec 2024 01:09:06 +0000 Subject: [PATCH] update --- examples/trompt_multi_gpu.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/trompt_multi_gpu.py b/examples/trompt_multi_gpu.py index 9a18ba43..c9cadcfd 100644 --- a/examples/trompt_multi_gpu.py +++ b/examples/trompt_multi_gpu.py @@ -48,19 +48,24 @@ def train( epoch: int, loader: DataLoader, optimizer: torch.optim.Optimizer, - num_classes: int, metric: torchmetrics.Metric, rank: int, ) -> float: model.train() loss_accum = torch.tensor(0.0, device=rank, dtype=torch.float32) - for tf in tqdm(loader, desc=f"Epoch {epoch:02d}", disable=rank != 0): + for tf in tqdm( + loader, + desc=f"Epoch {epoch:02d} (train)", + disable=rank != 0, + ): tf = tf.to(rank) # [batch_size, num_layers, num_classes] out = model(tf) + with torch.no_grad(): metric.update(out.mean(dim=1).argmax(dim=-1), tf.y) - num_layers = out.size(1) + + _, num_layers, num_classes = out.size() # [batch_size * num_layers, num_classes] pred = out.view(-1, num_classes) y = tf.y.repeat_interleave(num_layers) @@ -105,7 +110,7 @@ def test( return metric_value -def run(rank, world_size, args) -> None: +def run(rank: int, world_size: int, args: argparse.Namespace) -> None: dist.init_process_group( backend='nccl', init_method='env://', @@ -113,7 +118,8 @@ def run(rank, world_size, args) -> None: rank=rank, ) logging.basicConfig( - format=f"[rank={rank}] [%(asctime)s] %(levelname)s: %(message)s", + format=(f"[rank={rank}/{world_size}] " + f"[%(asctime)s] %(levelname)s: %(message)s"), level=logging.INFO, ) logger = logging.getLogger(__name__) @@ -122,7 +128,7 @@ def run(rank, world_size, args) -> None: assert dataset.task_type.is_classification # Ensure train, val and test splits are the same across all ranks by - # setting the seed before shuffling. + # setting the seed on each rank. torch.manual_seed(args.seed) dataset = dataset.shuffle() train_dataset, val_dataset, test_dataset = ( @@ -186,7 +192,6 @@ def run(rank, world_size, args) -> None: epoch, train_loader, optimizer, - dataset.num_classes, train_metric, rank, )