diff --git a/i6_models/assemblies/e_branchformer/__init__.py b/i6_models/assemblies/e_branchformer/__init__.py
new file mode 100644
index 00000000..4e6b16d8
--- /dev/null
+++ b/i6_models/assemblies/e_branchformer/__init__.py
@@ -0,0 +1 @@
+from .e_branchformer_v1 import *
diff --git a/i6_models/assemblies/e_branchformer/e_branchformer_v1.py b/i6_models/assemblies/e_branchformer/e_branchformer_v1.py
new file mode 100644
index 00000000..6061d5c6
--- /dev/null
+++ b/i6_models/assemblies/e_branchformer/e_branchformer_v1.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+__all__ = [
+    "EbranchformerBlockV1Config",
+    "EbranchformerBlockV1",
+    "EbranchformerEncoderV1Config",
+    "EbranchformerEncoderV1",
+]
+
+import torch
+from torch import nn
+from dataclasses import dataclass
+from typing import Tuple
+
+from i6_models.config import ModelConfiguration, ModuleFactoryV1
+from i6_models.parts.conformer import (
+    ConformerMHSAV1 as MHSAV1,
+    ConformerMHSAV1Config as MHSAV1Config,
+    ConformerPositionwiseFeedForwardV1 as PositionwiseFeedForwardV1,
+    ConformerPositionwiseFeedForwardV1Config as PositionwiseFeedForwardV1Config,
+)
+from i6_models.parts.e_branchformer import (
+    ConvolutionalGatingMLPV1Config,
+    ConvolutionalGatingMLPV1,
+    MergerV1Config,
+    MergerV1,
+)
+
+
+@dataclass
+class EbranchformerBlockV1Config(ModelConfiguration):
+    """
+    Attributes:
+        ff_cfg: Configuration for PositionwiseFeedForwardV1 module (used for both feed-forward modules)
+        mhsa_cfg: Configuration for MHSAV1 module
+        cgmlp_cfg: Configuration for ConvolutionalGatingMLPV1 module
+        merger_cfg: Configuration for MergerV1 module
+    """
+
+    ff_cfg: PositionwiseFeedForwardV1Config
+    mhsa_cfg: MHSAV1Config
+    cgmlp_cfg: ConvolutionalGatingMLPV1Config
+    merger_cfg: MergerV1Config
+
+
+class EbranchformerBlockV1(nn.Module):
+    """
+    Ebranchformer block module
+
+    Applies a feed-forward module (scaled by 0.5), the MHSA and cgMLP branches combined by the merger and a second
+    feed-forward module (scaled by 0.5), each with a residual connection, followed by a final layer norm.
+    """
+
+    def __init__(self, cfg: EbranchformerBlockV1Config):
+        """
+        :param cfg: e-branchformer block configuration with subunits for the different e-branchformer parts
+        """
+        super().__init__()
+        self.ff_1 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
+        self.mhsa = MHSAV1(cfg=cfg.mhsa_cfg)
+        self.cgmlp = ConvolutionalGatingMLPV1(model_cfg=cfg.cgmlp_cfg)
+        self.merger = MergerV1(model_cfg=cfg.merger_cfg)
+        self.ff_2 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
+        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)
+
+    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
+        """
+        :param x: input tensor of shape [B, T, F]
+        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
+        :return: torch.Tensor of shape [B, T, F]
+        """
+        x = 0.5 * self.ff_1(x) + x  # [B, T, F]
+        x_1 = self.mhsa(x, sequence_mask)  # [B, T, F]
+        x_2 = self.cgmlp(x)  # [B, T, F]
+        x = self.merger(x_1, x_2) + x  # [B, T, F]
+        x = 0.5 * self.ff_2(x) + x  # [B, T, F]
+        x = self.final_layer_norm(x)  # [B, T, F]
+        return x
+
+
+@dataclass
+class EbranchformerEncoderV1Config(ModelConfiguration):
+    """
+    Attributes:
+        num_layers: Number of e-branchformer layers in the e-branchformer encoder
+        frontend: A pair of ConformerFrontend and corresponding config
+        block_cfg: Configuration for EbranchformerBlockV1
+    """
+
+    num_layers: int
+
+    # nested configurations
+    frontend: ModuleFactoryV1
+    block_cfg: EbranchformerBlockV1Config
+
+
+class EbranchformerEncoderV1(nn.Module):
+    """
+    Implementation of the Branchformer with Enhanced merging (short e-branchformer), as in the original publication.
+    The model consists of a frontend and a stack of N e-branchformer blocks.
+    C.f. https://arxiv.org/pdf/2210.00077.pdf
+    """
+
+    def __init__(self, cfg: EbranchformerEncoderV1Config):
+        """
+        :param cfg: e-branchformer encoder configuration with subunits for frontend and e-branchformer blocks
+        """
+        super().__init__()
+
+        self.frontend = cfg.frontend()
+        self.module_list = torch.nn.ModuleList([EbranchformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
+
+    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param data_tensor: input tensor of shape [B, T', F']
+        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
+        :return: (output, out_seq_mask)
+            where output is torch.Tensor of shape [B, T, F],
+            out_seq_mask is a torch.Tensor of shape [B, T]
+
+        F': input feature dim, F: internal and output feature dim
+        T': data time dim, T: down-sampled time dim (internal time dim)
+        """
+        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F]
+        for module in self.module_list:
+            x = module(x, sequence_mask)  # [B, T, F]
+
+        return x, sequence_mask
diff --git a/i6_models/parts/e_branchformer/__init__.py b/i6_models/parts/e_branchformer/__init__.py
new file mode 100644
index 00000000..986cb079
--- /dev/null
+++ b/i6_models/parts/e_branchformer/__init__.py
@@ -0,0 +1,2 @@
+from .cgmlp import *
+from .merge import *
diff --git a/i6_models/parts/e_branchformer/cgmlp.py b/i6_models/parts/e_branchformer/cgmlp.py
new file mode 100644
index 00000000..9451e962
--- /dev/null
+++ b/i6_models/parts/e_branchformer/cgmlp.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+__all__ = ["ConvolutionalGatingMLPV1Config", "ConvolutionalGatingMLPV1"]
+
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+from torch import nn
+
+from i6_models.config import ModelConfiguration
+
+
+@dataclass
+class ConvolutionalGatingMLPV1Config(ModelConfiguration):
+    """
+    Attributes:
+        input_dim: input dimension
+        hidden_dim: hidden dimension (normally set to 6*input_dim as suggested by the paper), must be even
+        kernel_size: kernel size of the depthwise convolution layer, must be odd
+        dropout: dropout probability
+        activation: activation function
+    """
+
+    input_dim: int
+    hidden_dim: int
+    kernel_size: int
+    dropout: float
+    activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.gelu
+
+    def check_valid(self):
+        """Assert that the kernel size is odd and the hidden dimension is even."""
+        assert self.kernel_size % 2 == 1, "ConvolutionalGatingMLPV1 only supports odd kernel sizes"
+        assert self.hidden_dim % 2 == 0, "ConvolutionalGatingMLPV1 only supports even hidden_dim"
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.check_valid()
+
+
+class ConvolutionalGatingMLPV1(nn.Module):
+    """Convolutional Gating MLP (cgMLP)."""
+
+    def __init__(self, model_cfg: ConvolutionalGatingMLPV1Config):
+        """
+        :param model_cfg: model configuration for this module
+        """
+        super().__init__()
+
+        self.layer_norm_input = nn.LayerNorm(model_cfg.input_dim)
+        self.linear_ff = nn.Linear(in_features=model_cfg.input_dim, out_features=model_cfg.hidden_dim, bias=True)
+        self.activation = model_cfg.activation
+        # the csgu splits the hidden features into two halves, hence hidden_dim // 2 for all layers below
+        self.layer_norm_csgu = nn.LayerNorm(model_cfg.hidden_dim // 2)
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=model_cfg.hidden_dim // 2,
+            out_channels=model_cfg.hidden_dim // 2,
+            kernel_size=model_cfg.kernel_size,
+            padding=(model_cfg.kernel_size - 1) // 2,  # keeps the time dim unchanged for odd kernel sizes
+            groups=model_cfg.hidden_dim // 2,
+        )
+        self.linear_out = nn.Linear(in_features=model_cfg.hidden_dim // 2, out_features=model_cfg.input_dim, bias=True)
+        self.dropout = model_cfg.dropout
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        :param x: shape [B, T, F], F=input_dim
+        :return: shape [B, T, F], F=input_dim
+        """
+        x = self.layer_norm_input(x)  # [B, T, F]
+        x = self.linear_ff(x)  # [B, T, F']
+        x = self.activation(x)
+
+        # convolutional spatial gating unit (csgu)
+        x_1, x_2 = x.chunk(2, dim=-1)  # [B, T, F'//2], [B, T, F'//2]
+        x_2 = self.layer_norm_csgu(x_2)
+        # conv layers expect shape [B, F, T] so we have to transpose here
+        x_2 = x_2.transpose(1, 2)  # [B, F'//2, T]
+        x_2 = self.depthwise_conv(x_2)  # [B, F'//2, T]
+        x_2 = x_2.transpose(1, 2)  # [B, T, F'//2]
+        x = x_1 * x_2  # [B, T, F'//2]
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+
+        x = self.linear_out(x)  # [B, T, F]
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        return x
diff --git a/i6_models/parts/e_branchformer/merge.py b/i6_models/parts/e_branchformer/merge.py
new file mode 100644
index 00000000..81c02a16
--- /dev/null
+++ b/i6_models/parts/e_branchformer/merge.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+__all__ = ["MergerV1Config", "MergerV1"]
+
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+
+from i6_models.config import ModelConfiguration
+
+
+@dataclass
+class MergerV1Config(ModelConfiguration):
+    """
+    Attributes:
+        input_dim: input dimension
+        kernel_size: kernel size of the depthwise convolution layer, must be odd
+        dropout: dropout probability
+    """
+
+    input_dim: int
+    kernel_size: int
+    dropout: float
+
+    def check_valid(self):
+        """Assert that the kernel size is odd."""
+        assert self.kernel_size % 2 == 1, "MergerV1 only supports odd kernel sizes"
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.check_valid()
+
+
+class MergerV1(nn.Module):
+    def __init__(self, model_cfg: MergerV1Config):
+        """
+        The merge module to merge the outputs of local extractor and global extractor
+        Here we take the best variant from the E-branchformer paper (Fig. 3c), refer to
+        https://arxiv.org/abs/2210.00077 for more merge module variants
+
+        :param model_cfg: model configuration for this module
+        """
+        super().__init__()
+
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=model_cfg.input_dim * 2,
+            out_channels=model_cfg.input_dim * 2,
+            kernel_size=model_cfg.kernel_size,
+            padding=(model_cfg.kernel_size - 1) // 2,
+            groups=model_cfg.input_dim * 2,
+        )
+        self.linear_ff = nn.Linear(in_features=2 * model_cfg.input_dim, out_features=model_cfg.input_dim, bias=True)
+        self.dropout = model_cfg.dropout
+
+    def forward(self, x_1: torch.Tensor, x_2: torch.Tensor) -> torch.Tensor:
+        """
+        :param x_1: output of the first branch, shape [B, T, F], F=input_dim
+        :param x_2: output of the second branch, shape [B, T, F], F=input_dim
+        :return: merged tensor of shape [B, T, F], F=input_dim
+        """
+        x_concat = torch.cat([x_1, x_2], dim=-1)  # [B, T, 2F]
+        # conv layers expect shape [B, F, T] so we have to transpose here
+        x = x_concat.transpose(1, 2)  # [B, 2F, T]
+        x = self.depthwise_conv(x)  # [B, 2F, T]
+        x = x.transpose(1, 2)  # [B, T, 2F]
+        x = x + x_concat  # residual connection around the depthwise convolution
+        x = self.linear_ff(x)  # [B, T, F]
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        return x
diff --git a/tests/test_e_branchformer.py b/tests/test_e_branchformer.py
new file mode 100644
index 00000000..f693a233
--- /dev/null
+++ b/tests/test_e_branchformer.py
@@ -0,0 +1,36 @@
+from itertools import product
+
+import torch
+from torch import nn
+
+from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1Config, ConvolutionalGatingMLPV1
+from i6_models.parts.e_branchformer.merge import MergerV1Config, MergerV1
+
+
+def test_ConvolutionalGatingMLPV1():
+    def get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation):
+        input_dim = input_shape[-1]
+        cfg = ConvolutionalGatingMLPV1Config(input_dim, hidden_dim, kernel_size, dropout, activation)
+        e_branchformer_cgmlp_part = ConvolutionalGatingMLPV1(cfg)
+        x = torch.randn(input_shape)
+        y = e_branchformer_cgmlp_part(x)
+        return y.shape
+
+    for input_shape, hidden_dim, kernel_size, dropout, activation in product(
+        [(100, 5, 20), (200, 30, 10)], [120, 60], [9, 15], [0.1, 0.3], [nn.functional.gelu, nn.functional.relu]
+    ):
+        assert get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation) == input_shape
+
+
+def test_MergerV1():
+    def get_output_shape(input_shape, kernel_size, dropout):
+        input_dim = input_shape[-1]
+        cfg = MergerV1Config(input_dim, kernel_size, dropout)
+        e_branchformer_merge_part = MergerV1(cfg)
+        tensor_local = torch.randn(input_shape)
+        tensor_global = torch.randn(input_shape)
+        y = e_branchformer_merge_part(tensor_local, tensor_global)
+        return y.shape
+
+    for input_shape, kernel_size, dropout in product([(100, 5, 20), (200, 30, 10)], [15, 31], [0.1, 0.3]):
+        assert get_output_shape(input_shape, kernel_size, dropout) == input_shape