Add part: lstm block #66
base: main
Conversation
One minor Q.
i6_models/parts/lstm.py
Outdated
    enforce_sorted: bool

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict):
Same Q as in the other PR: why is this necessary now, and why hasn't it been for the other assemblies?
Co-authored-by: Albert Zeyer <[email protected]>
            if seq_len.get_device() >= 0:
                seq_len = seq_len.cpu()
Suggested change:
-            if seq_len.get_device() >= 0:
-                seq_len = seq_len.cpu()
+            seq_len = seq_len.cpu()
        )

    def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Why only when not scripting? Don't you want seq_len to always be on CPU?
I followed the example in the blstm part.
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    # during graph mode we have to assume all Tensors are on the correct device,
    # otherwise move lengths to the CPU if they are on GPU
    if seq_len.get_device() >= 0:
        seq_len = seq_len.cpu()
I did not copy the comment over... I have not yet gotten around to looking into why this is necessary.
@JackTemaki you implemented the BLSTM IIRC. Do you remember why it was done this way?
The question is: is this still relevant? It was something I added at some point, but if it is not needed for ONNX export, it should be removed until there is actually a reason for it.
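For reference, a minimal sketch of what the simplified forward would look like without the device guard (pack_padded_sequence expects the lengths tensor on CPU in any case; the helper name and enforce_sorted=False are assumptions, not the exact PR code):

from typing import Tuple

import torch
from torch import nn


def lstm_forward(lstm_stack: nn.LSTM, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # pack_padded_sequence expects the lengths tensor on CPU, so move it unconditionally
    seq_len = seq_len.cpu()
    packed = nn.utils.rnn.pack_padded_sequence(x, seq_len, batch_first=True, enforce_sorted=False)
    out, _ = lstm_stack(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True, padding_value=0.0)
    return out, seq_len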
Co-authored-by: Albert Zeyer <[email protected]>
    enforce_sorted: bool

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict[str, Any]):
I don't see this for other part configs; why do we need it here?
I use it in the model definition to convert a dict to the config class. "Need" might be a bit strong, but I like it :D
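For context, a minimal sketch of what such a from_dict helper does on a dataclass config (input_dim and hidden_dim are assumed field names; the others are taken from the diff above):

from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class LstmBlockV1Config:
    input_dim: int    # assumed field name
    hidden_dim: int   # assumed field name
    num_layers: int
    bias: bool
    dropout: float
    enforce_sorted: bool

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict[str, Any]) -> "LstmBlockV1Config":
        # turn a plain dict (e.g. from a serialized model config) into the typed config class
        return cls(**model_cfg_dict)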
i6_models/parts/lstm.py
Outdated
    num_layers: int
    bias: bool
    dropout: float
    bidirectional: bool
Should we allow this? I feel like if we have bidirectional here, the BLSTM part becomes redundant, which is maybe okay, but it might also lead to two different branches that do the same thing, which I am not sure we want (if there are potential extensions later). We could maybe also just deprecate the BLSTM block?
Good point, I'll just remove the flag. It is maybe a bit more readable to have two classes?!
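Sketched out, the resulting LSTM construction would then hardcode the direction (a hypothetical helper; field names not visible in the diff are assumptions), leaving the bidirectional case to the existing BLSTM part:

from torch import nn


def build_lstm_stack(cfg: "LstmBlockV1Config") -> nn.LSTM:
    # bidirectional is fixed to False instead of being a config flag,
    # so the BLSTM part remains the single place handling the bidirectional case
    return nn.LSTM(
        input_size=cfg.input_dim,    # assumed field name
        hidden_size=cfg.hidden_dim,  # assumed field name
        num_layers=cfg.num_layers,
        bias=cfg.bias,
        dropout=cfg.dropout,
        batch_first=True,
        bidirectional=False,
    )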
i6_models/parts/lstm.py
Outdated
class LstmBlockV1(nn.Module):
    def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict[str, Any]], **kwargs):
        """
        Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call.
Also add the "including dropout, batch-first variant, hardcoded to use B,T,F input" part please.
Could also add the supports scripting part.
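Putting both requests together, the class docstring could read roughly like this (a sketch of the wording, not the actual PR text):

from torch import nn


class LstmBlockV1(nn.Module):
    """
    Model definition of an LSTM block: a single uni-directional LSTM stack including dropout,
    batch-first variant, hardcoded to use [B, T, F] input. The sequence is packed and padded
    inside the forward call; scripting is supported.
    """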
            bidirectional=self.cfg.bidirectional,
        )

    def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
doc pls.
i6_models/parts/lstm.py
Outdated
        )

        lstm_out, _ = self.lstm_stack(lstm_packed_in)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
Just out of curiosity: why does black force the new lines here but not for the blstm? Shouldn't it be the same line length?
Well, this depends on whether you manually set the trailing comma. If it is set, black will force the new lines.
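This is black's "magic trailing comma" behavior. A minimal illustration using the call from the diff above (the surrounding setup is only there to make the snippet self-contained):

import torch
from torch import nn

lengths = torch.tensor([3, 2])
packed = nn.utils.rnn.pack_padded_sequence(torch.zeros(2, 3, 4), lengths, batch_first=True)

# trailing comma present -> black keeps one argument per line
lstm_out, seq_len = nn.utils.rnn.pad_packed_sequence(
    packed,
    padding_value=0.0,
    batch_first=True,
)

# no trailing comma (and the call fits the line length) -> black collapses it onto one line
lstm_out, seq_len = nn.utils.rnn.pad_packed_sequence(packed, padding_value=0.0, batch_first=True)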
Adds LSTM Block