diff --git a/src/seamless_communication/cli/m4t/classification_head/model.py b/src/seamless_communication/cli/m4t/classification_head/model.py
index a6d31d74..cb7de345 100644
--- a/src/seamless_communication/cli/m4t/classification_head/model.py
+++ b/src/seamless_communication/cli/m4t/classification_head/model.py
@@ -9,12 +9,34 @@ def __init__(self, embed_dim, n_layers, n_classes, n_heads = 4):
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads=n_heads)
+        # forward() feeds (Batch, Seq, Embed) tensors; without batch_first=True
+        # nn.MultiheadAttention would treat dim 0 as the sequence axis and mix
+        # information across samples in the batch.
+        self.attn = nn.MultiheadAttention(
+            embed_dim, num_heads=n_heads, batch_first=True
+        )
         self.layers = nn.ModuleList(
-            [ nn.Linear(embed_dim, embed_dim) for _ in range(n_layers) ] + \
-            [ nn.Linear(embed_dim, n_classes) ])
+            [
+                nn.Sequential(
+                    nn.Linear(embed_dim, embed_dim),
+                    nn.BatchNorm1d(embed_dim),  # train mode needs batch size > 1
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                )
+                for _ in range(n_layers)
+            ]
+            # Output layer stays linear; forward() applies the softmax.
+            + [nn.Linear(embed_dim, n_classes)]
+        )
 
     def forward(self, x):
+        """Return class probabilities for a batch of embedded sequences.
+
+        x is (Batch, Seq, Embed); the result is (Batch, n_classes), float32.
+        """
         # (Batch, Seq, Embed)
         x, _ = self.attn(x, x, x)
         x = x[:, 0]
         for layer in self.layers:
-            x = nn.functional.relu(layer(x))
-        return nn.functional.softmax(x).float()
\ No newline at end of file
+            # Hidden blocks carry their own ReLU; the output layer has none.
+            x = layer(x)
+        # x is (Batch, n_classes) here; name the class axis explicitly because
+        # calling softmax without dim is deprecated.
+        return nn.functional.softmax(x, dim=-1).float()