Skip to content

Commit

Permalink
update
Browse files — browse the repository at this point in the history
  • Loading branch information
akihironitta committed Dec 27, 2024
1 parent a494e01 commit 1941549
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions examples/trompt_multi_gpu.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -187,7 +187,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
test_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)

best_val_acc = 0.0
best_test_acc = 0.0
test_acc = 0.0
for epoch in range(1, args.epochs + 1):
train_loader.sampler.set_epoch(epoch)
train_loss, train_acc = train(
Expand All @@ -206,28 +206,26 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
rank,
'val',
)
test_acc = test(
model,
epoch,
test_loader,
test_metric,
rank,
'test',
)
if best_val_acc < val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
test_acc = test(
model,
epoch,
test_loader,
test_metric,
rank,
'test',
)
if rank == 0:
print(f"Train Loss: {train_loss:.4f}, "
f"Train Acc: {train_acc:.4f}, "
f"Val Acc: {val_acc:.4f}, "
f"Test Acc: {test_acc:.4f}")
f"Val Acc: {val_acc:.4f}")

lr_scheduler.step()

if rank == 0:
print(f"Best Val Acc: {best_val_acc:.4f}, "
f"Best Test Acc: {best_test_acc:.4f}")
f"Test Acc: {test_acc:.4f}")

dist.destroy_process_group()
logging.info("Process group destroyed")
Expand All @@ -241,7 +239,7 @@ def run(rank: int, world_size: int, args: argparse.Namespace) -> None:
parser.add_argument("--num_layers", type=int, default=6)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()

Expand Down

0 comments on commit 1941549

Please sign in to comment.