Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
weihua916 committed Oct 14, 2023
1 parent c154264 commit 9af7d25
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/data_frame_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@
col_stats = dataset.col_stats
elif args.model_type == 'TabTransformer':
model_search_space = {
'channels': [32, 64, 128],
'channels': [64, 128, 256],
'num_layers': [4, 6, 8],
'num_heads': [2, 4, 8],
'encoder_pad_size': [2],
'encoder_pad_size': [2, 4],
'attn_dropout': [0, 0.2],
'ffn_dropout': [0, 0.2],
}
Expand Down
1 change: 1 addition & 0 deletions torch_frame/nn/models/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(

def reset_parameters(self):
self.cat_encoder.reset_parameters()
torch.nn.init.normal_(self.pad_embedding.weight, std=0.01)
self.num_norm.reset_parameters()
for tab_transformer_conv in self.tab_transformer_convs:
tab_transformer_conv.reset_parameters()
Expand Down

0 comments on commit 9af7d25

Please sign in to comment.