diff --git a/benchmark/data_frame_benchmark.py b/benchmark/data_frame_benchmark.py
index 3916e3486..5685403d4 100644
--- a/benchmark/data_frame_benchmark.py
+++ b/benchmark/data_frame_benchmark.py
@@ -263,7 +263,7 @@ def train(
             pred, y = model(tf, mixup_encoded=True)
         elif isinstance(model, Trompt):
             # Trompt uses the layer-wise loss
-            pred = model.forward_stacked(tf)
+            pred = model(tf)
             num_layers = pred.size(1)
             # [batch_size * num_layers, num_classes]
             pred = pred.view(-1, out_channels)
@@ -294,6 +294,8 @@ def test(
     for tf in loader:
         tf = tf.to(device)
         pred = model(tf)
+        if isinstance(model, Trompt):
+            pred = pred.mean(dim=1)
         if dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
             pred = pred.argmax(dim=-1)
         elif dataset.task_type == TaskType.REGRESSION:
diff --git a/benchmark/data_frame_text_benchmark.py b/benchmark/data_frame_text_benchmark.py
index 7a8cb05f2..c440024ed 100644
--- a/benchmark/data_frame_text_benchmark.py
+++ b/benchmark/data_frame_text_benchmark.py
@@ -307,7 +307,7 @@ def train(
         y = tf.y
         if isinstance(model, Trompt):
             # Trompt uses the layer-wise loss
-            pred = model.forward_stacked(tf)
+            pred = model(tf)
             num_layers = pred.size(1)
             # [batch_size * num_layers, num_classes]
             pred = pred.view(-1, out_channels)
@@ -337,6 +337,10 @@ def test(
     for tf in loader:
         tf = tf.to(device)
         pred = model(tf)
+        if isinstance(model, Trompt):
+            # [batch_size, num_layers, out_channels]
+            # -> [batch_size, out_channels]
+            pred = pred.mean(dim=1)
         if dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
             pred = pred.argmax(dim=-1)
         elif dataset.task_type == TaskType.REGRESSION:
diff --git a/examples/trompt.py b/examples/trompt.py
index 9d7d6e964..342ab653e 100644
--- a/examples/trompt.py
+++ b/examples/trompt.py
@@ -90,7 +90,7 @@ def train(epoch: int) -> float:
     for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
         tf = tf.to(device)
         # [batch_size, num_layers, num_classes]
-        out = model.forward_stacked(tf)
+        out = model(tf)
         num_layers = out.size(1)
         # [batch_size * num_layers, num_classes]
         pred = out.view(-1, dataset.num_classes)
@@ -112,7 +112,7 @@ def test(loader: DataLoader) -> float:
 
     for tf in loader:
         tf = tf.to(device)
-        pred = model(tf)
+        pred = model(tf).mean(dim=1)
         pred_class = pred.argmax(dim=-1)
         accum += float((tf.y == pred_class).sum())
         total_count += len(tf.y)
diff --git a/examples/trompt_multi_gpu.py b/examples/trompt_multi_gpu.py
new file mode 100644
index 000000000..9a18ba435
--- /dev/null
+++ b/examples/trompt_multi_gpu.py
@@ -0,0 +1,245 @@
+import argparse
+import logging
+import os
+import os.path as osp
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+import torchmetrics
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+from torch_frame.data import DataLoader
+from torch_frame.datasets import TabularBenchmark
+from torch_frame.nn import Trompt
+
+
+def prepare_dataset(dataset_str: str) -> TabularBenchmark:
+    path = osp.join(
+        osp.dirname(osp.realpath(__file__)),
+        "..",
+        "data",
+        dataset_str,
+    )
+    materialized_path = osp.join(path, 'materialized_data.pt')
+    if dist.get_rank() == 0:
+        logging.info(f"Preparing dataset '{dataset_str}' from '{path}'")
+        dataset = TabularBenchmark(root=path, name=dataset_str)
+        logging.info("Materializing dataset")
+        dataset.materialize(path=materialized_path)
+
+    dist.barrier()
+    if dist.get_rank() != 0:
+        logging.info(f"Preparing dataset '{dataset_str}' from '{path}'")
+        dataset = TabularBenchmark(root=path, name=dataset_str)
+        logging.info("Loading materialized dataset")
+        dataset.materialize(path=materialized_path)
+
+    dist.barrier()
+    return dataset
+
+
+def train(
+    model: DistributedDataParallel,
+    epoch: int,
+    loader: DataLoader,
+    optimizer: torch.optim.Optimizer,
+    num_classes: int,
+    metric: torchmetrics.Metric,
+    rank: int,
+) -> float:
+    model.train()
+    loss_accum = torch.tensor(0.0, device=rank, dtype=torch.float32)
+    for tf in tqdm(loader, desc=f"Epoch {epoch:02d}", disable=rank != 0):
+        tf = tf.to(rank)
+        # [batch_size, num_layers, num_classes]
+        out = model(tf)
+        with torch.no_grad():
+            metric.update(out.mean(dim=1).argmax(dim=-1), tf.y)
+        num_layers = out.size(1)
+        # [batch_size * num_layers, num_classes]
+        pred = out.view(-1, num_classes)
+        y = tf.y.repeat_interleave(num_layers)
+        # Layer-wise logit loss
+        loss = F.cross_entropy(pred, y)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        loss_accum += loss
+
+    # The number of samples is guaranteed to be the same across all ranks
+    # because of DistributedSampler(drop_last=True).
+    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
+    metric_value = metric.compute()
+    metric.reset()
+    return loss_accum, metric_value
+
+
+@torch.no_grad()
+def test(
+    model: DistributedDataParallel,
+    epoch: int,
+    loader: DataLoader,
+    metric: torchmetrics.Metric,
+    rank: int,
+    desc: str,
+) -> float:
+    model.eval()
+    for tf in tqdm(
+            loader,
+            desc=f"Epoch {epoch:02d} ({desc})",
+            disable=rank != 0,
+    ):
+        tf = tf.to(rank)
+        # [batch_size, num_layers, out_channels] -> [batch_size, out_channels]
+        pred = model(tf).mean(dim=1)
+        pred_class = pred.argmax(dim=-1)
+        metric.update(pred_class, tf.y)
+
+    metric_value = metric.compute()
+    metric.reset()
+    return metric_value
+
+
+def run(rank, world_size, args) -> None:
+    dist.init_process_group(
+        backend='nccl',
+        init_method='env://',
+        world_size=world_size,
+        rank=rank,
+    )
+    logging.basicConfig(
+        format=f"[rank={rank}] [%(asctime)s] %(levelname)s: %(message)s",
+        level=logging.INFO,
+    )
+    logger = logging.getLogger(__name__)
+    logger.info(f"Running on rank {rank} of {world_size}")
+    dataset = prepare_dataset(args.dataset)
+    assert dataset.task_type.is_classification
+
+    # Ensure train, val and test splits are the same across all ranks by
+    # setting the seed before shuffling.
+    torch.manual_seed(args.seed)
+    dataset = dataset.shuffle()
+    train_dataset, val_dataset, test_dataset = (
+        dataset[:0.7],
+        dataset[0.7:0.79],
+        dataset[0.79:],
+    )
+    train_loader = DataLoader(
+        train_dataset.tensor_frame,
+        batch_size=args.batch_size,
+        sampler=DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            drop_last=True,
+        ),
+    )
+    val_loader = DataLoader(
+        val_dataset.tensor_frame,
+        batch_size=args.batch_size,
+        sampler=DistributedSampler(
+            val_dataset,
+            shuffle=False,
+            drop_last=False,
+        ),
+    )
+    test_loader = DataLoader(
+        test_dataset.tensor_frame,
+        batch_size=args.batch_size,
+        sampler=DistributedSampler(
+            test_dataset,
+            shuffle=False,
+            drop_last=False,
+        ),
+    )
+    model = Trompt(
+        channels=args.channels,
+        out_channels=dataset.num_classes,
+        num_prompts=args.num_prompts,
+        num_layers=args.num_layers,
+        col_stats=dataset.col_stats,
+        col_names_dict=train_dataset.tensor_frame.col_names_dict,
+    ).to(rank)
+    model = DistributedDataParallel(model, device_ids=[rank])
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
+
+    metrics_kwargs = {
+        "task": "multiclass",
+        "num_classes": dataset.num_classes,
+    }
+    train_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
+    val_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
+    test_metric = torchmetrics.Accuracy(**metrics_kwargs).to(rank)
+
+    best_val_acc = 0.0
+    best_test_acc = 0.0
+    for epoch in range(1, args.epochs + 1):
+        train_loader.sampler.set_epoch(epoch)
+        train_loss, train_acc = train(
+            model,
+            epoch,
+            train_loader,
+            optimizer,
+            dataset.num_classes,
+            train_metric,
+            rank,
+        )
+        val_acc = test(
+            model,
+            epoch,
+            val_loader,
+            val_metric,
+            rank,
+            'val',
+        )
+        test_acc = test(
+            model,
+            epoch,
+            test_loader,
+            test_metric,
+            rank,
+            'test',
+        )
+        if best_val_acc < val_acc:
+            best_val_acc = val_acc
+            best_test_acc = test_acc
+        if rank == 0:
+            print(f"Train Loss: {train_loss:.4f}, "
+                  f"Train Acc: {train_acc:.4f}, "
+                  f"Val Acc: {val_acc:.4f}, "
+                  f"Test Acc: {test_acc:.4f}")
+
+        lr_scheduler.step()
+
+    if rank == 0:
+        print(f"Best Val Acc: {best_val_acc:.4f}, "
+              f"Best Test Acc: {best_test_acc:.4f}")
+
+    dist.destroy_process_group()
+    logging.info("Process group destroyed")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, default="california")
+    parser.add_argument("--channels", type=int, default=128)
+    parser.add_argument("--num_prompts", type=int, default=128)
+    parser.add_argument("--num_layers", type=int, default=6)
+    parser.add_argument("--batch_size", type=int, default=256)
+    parser.add_argument("--lr", type=float, default=0.001)
+    parser.add_argument("--epochs", type=int, default=200)
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--compile", action="store_true")
+    args = parser.parse_args()
+
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    world_size = torch.cuda.device_count()
+    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
diff --git a/test/nn/models/test_trompt.py b/test/nn/models/test_trompt.py
index 3ed78a966..b117c3383 100644
--- a/test/nn/models/test_trompt.py
+++ b/test/nn/models/test_trompt.py
@@ -47,7 +47,5 @@ def test_trompt(batch_size, use_stype_encoder_dicts):
         stype_encoder_dicts=stype_encoder_dicts,
     )
     model.reset_parameters()
-    out = model.forward_stacked(tensor_frame)
-    assert out.shape == (batch_size, num_layers, out_channels)
     pred = model(tensor_frame)
- assert pred.shape == (batch_size, out_channels) + assert pred.shape == (batch_size, num_layers, out_channels) diff --git a/torch_frame/nn/models/trompt.py b/torch_frame/nn/models/trompt.py index ccb39524e..a7cf36ca8 100644 --- a/torch_frame/nn/models/trompt.py +++ b/torch_frame/nn/models/trompt.py @@ -122,7 +122,7 @@ def reset_parameters(self) -> None: trompt_conv.reset_parameters() self.trompt_decoder.reset_parameters() - def forward_stacked(self, tf: TensorFrame) -> Tensor: + def forward(self, tf: TensorFrame) -> Tensor: r"""Transforming :class:`TensorFrame` object into a series of output predictions at each layer. Used during training to compute layer-wise loss. @@ -152,6 +152,3 @@ def forward_stacked(self, tf: TensorFrame) -> Tensor: # [batch_size, num_layers, out_channels] stacked_out = torch.cat(outs, dim=1) return stacked_out - - def forward(self, tf: TensorFrame) -> Tensor: - return self.forward_stacked(tf).mean(dim=1)
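
Usage note: with this patch, `Trompt.forward` (formerly `forward_stacked`) returns the stacked per-layer logits of shape `[batch_size, num_layers, out_channels]`. Training code flattens them for the layer-wise loss, while evaluation code averages over the layer dimension before taking the argmax. Below is a minimal, self-contained sketch of that calling pattern; the random tensor merely stands in for `out = model(tf)`, and the variable names and shapes are illustrative, not part of the library API.

```python
import torch
import torch.nn.functional as F

batch_size, num_layers, num_classes = 32, 6, 10

# Stand-in for `out = model(tf)`, which now returns stacked per-layer logits
# of shape [batch_size, num_layers, num_classes].
out = torch.randn(batch_size, num_layers, num_classes)
y = torch.randint(num_classes, (batch_size,))

# Training: layer-wise loss. Flatten the layer dimension and repeat the
# targets so every layer's logits are supervised independently.
pred = out.view(-1, num_classes)            # [batch_size * num_layers, num_classes]
target = y.repeat_interleave(num_layers)    # [batch_size * num_layers]
loss = F.cross_entropy(pred, target)

# Evaluation: average the per-layer logits, then take the argmax.
pred_class = out.mean(dim=1).argmax(dim=-1)  # [batch_size]
```

The new `examples/trompt_multi_gpu.py` follows the same pattern per rank, averaging the accumulated loss across ranks with `dist.all_reduce` and keeping the train/val/test splits identical on every rank by seeding before `dataset.shuffle()`.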