diff --git a/examples/data_frame_benchmark.py b/examples/data_frame_benchmark.py index d989ba2a..136dde9d 100644 --- a/examples/data_frame_benchmark.py +++ b/examples/data_frame_benchmark.py @@ -153,10 +153,10 @@ col_stats = dataset.col_stats elif args.model_type == 'TabTransformer': model_search_space = { - 'channels': [32, 64, 128], + 'channels': [64, 128, 256], 'num_layers': [4, 6, 8], 'num_heads': [2, 4, 8], - 'encoder_pad_size': [2], + 'encoder_pad_size': [2, 4], 'attn_dropout': [0, 0.2], 'ffn_dropout': [0, 0.2], } diff --git a/torch_frame/nn/models/tab_transformer.py b/torch_frame/nn/models/tab_transformer.py index 8016991b..3c7ab2d6 100644 --- a/torch_frame/nn/models/tab_transformer.py +++ b/torch_frame/nn/models/tab_transformer.py @@ -96,6 +96,7 @@ def __init__( def reset_parameters(self): self.cat_encoder.reset_parameters() + torch.nn.init.normal_(self.pad_embedding.weight, std=0.01) self.num_norm.reset_parameters() for tab_transformer_conv in self.tab_transformer_convs: tab_transformer_conv.reset_parameters()