Skip to content

Commit

Permalink
Stack translation fix for TorchFX
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Mar 1, 2024
1 parent 5e907a6 commit 3f9a245
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/frontends/pytorch/src/op/cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
};

OutputVector translate_stack_fx(const NodeContext& context) {
num_inputs_check(context, 2, context.get_input_size());
num_inputs_check(context, 1, context.get_input_size());
auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
std::deque<Output<Node>> list_elems;
auto num_elements = context.get_input_size();
Expand All @@ -112,14 +112,12 @@ OutputVector translate_stack_fx(const NodeContext& context) {
list_elems.push_back(stack_input);
}
int64_t axis = 0;
if (context.get_input_size() > 2)
axis = context.const_input<int64_t>(context.get_input_size() - 1);
if (!context.get_input_type(context.get_input_size() - 1).is<type::List>()) {
if (!context.get_input_type(num_elements - 1).is<type::List>()) {
// axis can be not present and that means that last input will have List type
axis = context.const_input<int64_t>(context.get_input_size() - 1);
axis = context.const_input<int64_t>(num_elements - 1);
} else {
auto stack_input =
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(context.get_input_size() - 1)), dim));
context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(num_elements - 1)), dim));
list_elems.push_back(stack_input);
}
return translate_cat_common(context, list_elems, axis, true);
Expand Down

0 comments on commit 3f9a245

Please sign in to comment.