
Commit

change model to increase train loss
am831 committed May 19, 2024
1 parent 19653d7 commit 588b0a8
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/seamless_communication/cli/m4t/classification_head/model.py
@@ -8,14 +8,19 @@ def __init__(self, embed_dim, n_layers, n_classes, n_heads = 4):
 
         self.attn = nn.MultiheadAttention(embed_dim, num_heads=n_heads)
         self.layers = nn.ModuleList(
-            [ nn.Linear(embed_dim, embed_dim) for _ in range(n_layers) ] + \
-            [ nn.Linear(embed_dim, n_classes) ])
+            [ nn.Sequential(
+                nn.Linear(embed_dim, embed_dim),
+                nn.BatchNorm1d(embed_dim), # normalize batch
+                nn.ReLU(),                 # activation function
+                nn.Dropout(0.5)            # prevent overfitting
+            ) for _ in range(n_layers)
+            ] + [ nn.Linear(embed_dim, n_classes) ])
 
     def forward(self, x):
         # (Batch, Seq, Embed)
         x, _ = self.attn(x, x, x)
         x = x[:, 0]
         for layer in self.layers:
-            x = nn.functional.relu(layer(x))
+            x = layer(x)
         return nn.functional.softmax(x).float()

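For context, a minimal sketch of how the module might look and be exercised after this change. Only the __init__ and forward bodies come from the diff above; the class name ClassificationHead, the super().__init__() boilerplate, and the values used in the example call are assumptions chosen for illustration.

import torch
import torch.nn as nn


class ClassificationHead(nn.Module):  # assumed name; the diff only shows the method bodies
    def __init__(self, embed_dim, n_layers, n_classes, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=n_heads)
        self.layers = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.BatchNorm1d(embed_dim),  # normalize batch
                nn.ReLU(),                  # activation function
                nn.Dropout(0.5),            # prevent overfitting
            ) for _ in range(n_layers)]
            + [nn.Linear(embed_dim, n_classes)])

    def forward(self, x):
        # (Batch, Seq, Embed)
        x, _ = self.attn(x, x, x)
        x = x[:, 0]
        for layer in self.layers:
            x = layer(x)
        # dim is passed explicitly here only to avoid PyTorch's implicit-dim
        # warning; the committed code calls softmax without it.
        return nn.functional.softmax(x, dim=-1).float()


# Example call with illustrative sizes (not taken from the repository):
head = ClassificationHead(embed_dim=1024, n_layers=2, n_classes=3)
head.train()  # BatchNorm1d and Dropout behave differently in train vs. eval mode
dummy = torch.randn(8, 16, 1024)  # (Batch, Seq, Embed); batch > 1 is required by BatchNorm1d in train mode
probs = head(dummy)
print(probs.shape)  # torch.Size([8, 3])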