Patch summary (free text before the first "diff --git" line):
  * _get_frame_level_feat now returns the pair (out, out4) instead of out.
  * get_frame_level_feat takes element [0] of that pair, so its own return
    value is built the same way as before.
  * forward now returns the pair (out4, out) instead of the single value out.
NOTE(review): the return arity of forward changes; its callers are not
visible in this patch -- confirm they unpack two values.
NOTE(review): line breaks, indentation and blank context lines were rebuilt
from a whitespace-collapsed copy using the hunk counts (20 old / 21 new);
run `git apply --check` before relying on this file.

diff --git a/wespeaker/models/ecapa_tdnn.py b/wespeaker/models/ecapa_tdnn.py
index 99824f9d..8c76222f 100644
--- a/wespeaker/models/ecapa_tdnn.py
+++ b/wespeaker/models/ecapa_tdnn.py
@@ -217,20 +217,21 @@ def _get_frame_level_feat(self, x):
 
         out = torch.cat([out2, out3, out4], dim=1)
         out = self.conv(out)
-        return out
+        return out, out4
 
     def get_frame_level_feat(self, x):
         # for outer interface
-        out = self._get_frame_level_feat(x).permute(0, 2, 1)
+        out = self._get_frame_level_feat(x)[0].permute(0, 2, 1)
         return out  # (B, T, D)
 
     def forward(self, x):
-        out = F.relu(self._get_frame_level_feat(x))
+        out, out4 = self._get_frame_level_feat(x)
+        out = F.relu(out)
         out = self.bn(self.pool(out))
         out = self.linear(out)
         if self.emb_bn:
             out = self.bn2(out)
-        return out
+        return out4, out
 
 
 def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False):