Skip to content

Commit

Permalink
fix return type for convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 13, 2023
1 parent 464fa13 commit a6ae02e
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
8 changes: 4 additions & 4 deletions torch_frame/nn/conv/ft_transformer_convs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
x (Tensor): Input tensor of shape [batch_size, num_cols, channels]
Returns:
x (Tensor): Output tensor of shape [batch_size, num_cols, channels]
corresponding to the input columns.
x_cls (Tensor): Output tensor of shape [batch_size, channels],
corresponding to the added CLS token column.
(torch.Tensor, torch.Tensor): (Output tensor of shape
[batch_size, num_cols, channels] corresponding to the input
columns, Output tensor of shape [batch_size, channels],
corresponding to the added CLS token column.)
"""
B, _, _ = x.shape
# [batch_size, num_cols, channels]
Expand Down
2 changes: 1 addition & 1 deletion torch_frame/nn/conv/table_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
3-dimensional tensor.
Args:
x (Tensor): Input column-wise tensor of shape
x (torch.Tensor): Input column-wise tensor of shape
:obj:`[batch_size, num_cols, hidden_channels]`.
args (Any): Extra arguments.
kwargs (Any): Extra keyward arguments.
Expand Down
6 changes: 3 additions & 3 deletions torch_frame/nn/conv/trompt_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def forward(self, x: Tensor, x_prompt: Tensor) -> Tensor:
the next layer.
Args:
x (Tensor): Feature-based embedding of shape
x (torch.Tensor): Feature-based embedding of shape
[batch_size, num_cols, channels]
x_prompt (Tensor): Input prompt embeddings of shape
x_prompt (torch.Tensor): Input prompt embeddings of shape
[batch_size, num_prompts, channels]
Returns:
x_prompt (Tensor): Output prompt embeddings for the next layer. The
torch.Tensor: Output prompt embeddings for the next layer. The
shape is [batch_size, num_prompts, channels].
"""
batch_size = len(x)
Expand Down
20 changes: 10 additions & 10 deletions torch_frame/nn/models/excelformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ class ExcelFormer(Module):
names are sorted based on the ordering that appear in
:obj:`tensor_frame.feat_dict`. Available as
:obj:`tensor_frame.col_names_dict`.
diam_dropout (float, optional): diam_dropout (default: :obj:`0.0`)
aium_dropout (float, optional): aium_dropout (default: :obj:`0.0`)
residual_dropout (float, optional): residual dropout (default: `0.0`)
diam_dropout (float, optional): diam_dropout. (default: :obj:`0.0`)
aium_dropout (float, optional): aium_dropout. (default: :obj:`0.0`)
residual_dropout (float, optional): residual dropout.
(default: :obj:`0.0`)
"""
def __init__(
self,
Expand Down Expand Up @@ -177,15 +178,14 @@ def forward_mixup(
tf (TensorFrame): Input :obj:`TensorFrame` object.
beta (float, optional): Shape parameter for beta distribution to
calculate shuffle rate in mixup. Only useful when mixup is
true. (default: 0.5)
true. (default: :obj:`0.5`)
Returns:
torch.Tensor: The mixed up output embeddings of size
[batch_size, out_channels].
torch.Tensor: Output :obj:`Tensor` y_mixedup will be
returned only when mixup is set to true. The size is
[batch_size, num_classes] for classification and
[batch_size, 1] for regression.
(torch.Tensor, torch.Tensor): The first :obj:`Tensor` is the mixed
up output embeddings of size [batch_size, out_channels].
The second :obj:`Tensor` y_mixedup will be returned only when
mixup is set to true. The size is [batch_size, num_classes] for
classification and [batch_size, 1] for regression.
"""
# Mixup numerical features
x_mixedup, y_mixedup = feature_mixup(
Expand Down
2 changes: 1 addition & 1 deletion torch_frame/nn/models/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def forward(self, tf: TensorFrame) -> Tensor:
r"""Transforming :obj:`TensorFrame` object into output prediction.
Args:
x (Tensor): Input :obj:`TensorFrame` object.
x (torch.Tensor): Input :obj:`TensorFrame` object.
Returns:
torch.Tensor: Output of shape [batch_size, out_channels].
Expand Down

0 comments on commit a6ae02e

Please sign in to comment.