diff --git a/examples/trompt_multi_gpu.py b/examples/trompt_multi_gpu.py
index b4e1a12e..caaac158 100644
--- a/examples/trompt_multi_gpu.py
+++ b/examples/trompt_multi_gpu.py
@@ -187,7 +187,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
     test_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
 
     best_val_acc = 0.0
-    best_test_acc = 0.0
+    test_acc = 0.0
     for epoch in range(1, args.epochs + 1):
         train_loader.sampler.set_epoch(epoch)
         train_loss, train_acc = train(
@@ -206,28 +206,26 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
             rank,
             'val',
         )
-        test_acc = test(
-            model,
-            epoch,
-            test_loader,
-            test_metric,
-            rank,
-            'test',
-        )
 
         if best_val_acc < val_acc:
             best_val_acc = val_acc
-            best_test_acc = test_acc
+            test_acc = test(
+                model,
+                epoch,
+                test_loader,
+                test_metric,
+                rank,
+                'test',
+            )
 
         if rank == 0:
             print(f"Train Loss: {train_loss:.4f}, "
                   f"Train Acc: {train_acc:.4f}, "
-                  f"Val Acc: {val_acc:.4f}, "
-                  f"Test Acc: {test_acc:.4f}")
+                  f"Val Acc: {val_acc:.4f}")
         lr_scheduler.step()
 
     if rank == 0:
         print(f"Best Val Acc: {best_val_acc:.4f}, "
-              f"Best Test Acc: {best_test_acc:.4f}")
+              f"Test Acc: {test_acc:.4f}")
 
     dist.destroy_process_group()
     logging.info("Process group destroyed")
@@ -241,7 +239,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
     parser.add_argument("--num_layers", type=int, default=6)
    parser.add_argument("--batch_size", type=int, default=256)
     parser.add_argument("--lr", type=float, default=0.001)
-    parser.add_argument("--epochs", type=int, default=200)
+    parser.add_argument("--epochs", type=int, default=50)
     parser.add_argument("--seed", type=int, default=0)
 
     args = parser.parse_args()
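
Note: a minimal sketch of the evaluation pattern this patch introduces, shown outside the distributed setup. The names `fit_with_lazy_test_eval`, `train_one_epoch`, `evaluate_val`, and `evaluate_test` are hypothetical placeholders, not helpers from the example script. It illustrates why `best_test_acc` becomes redundant: the test split is evaluated only when validation accuracy improves, so `test_acc` already corresponds to the best-validation epoch.

```python
from typing import Callable, Tuple


def fit_with_lazy_test_eval(
    train_one_epoch: Callable[[int], float],  # hypothetical: runs one training epoch
    evaluate_val: Callable[[int], float],     # hypothetical: returns validation accuracy
    evaluate_test: Callable[[int], float],    # hypothetical: returns test accuracy
    num_epochs: int,
) -> Tuple[float, float]:
    """Evaluate the test split only when validation accuracy improves."""
    best_val_acc = 0.0
    test_acc = 0.0
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(epoch)
        val_acc = evaluate_val(epoch)
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            # Run the (comparatively expensive) test pass only on improvement;
            # the reported test accuracy then always matches the best-validation epoch.
            test_acc = evaluate_test(epoch)
    return best_val_acc, test_acc
```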