Skip to content

Commit

Permalink
Remove apply functionality from BatchLinOp
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Oct 6, 2023
1 parent ff3c0de commit d2a4145
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 369 deletions.
185 changes: 1 addition & 184 deletions core/test/base/batch_lin_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,6 @@ class DummyBatchLinOp : public gko::batch::EnableBatchLinOp<DummyBatchLinOp>,
gko::batch_dim<2> size = gko::batch_dim<2>{})
: gko::batch::EnableBatchLinOp<DummyBatchLinOp>(exec, size)
{}

int called = 0;

protected:
void apply_impl(const gko::batch::BatchLinOp* b,
gko::batch::BatchLinOp* x) const override
{
this->called = 1;
}

void apply_impl(const gko::batch::BatchLinOp* alpha,
const gko::batch::BatchLinOp* b,
const gko::batch::BatchLinOp* beta,
gko::batch::BatchLinOp* x) const override
{
this->called = 2;
}
};


Expand All @@ -84,37 +67,13 @@ class EnableBatchLinOp : public ::testing::Test {
op{DummyBatchLinOp::create(ref2,
gko::batch_dim<2>(1, gko::dim<2>{3, 5}))},
op2{DummyBatchLinOp::create(ref2,
gko::batch_dim<2>(2, gko::dim<2>{3, 5}))},
alpha{DummyBatchLinOp::create(
ref, gko::batch_dim<2>(1, gko::dim<2>{1, 1}))},
alpha2{DummyBatchLinOp::create(
ref, gko::batch_dim<2>(2, gko::dim<2>{1, 1}))},
beta{DummyBatchLinOp::create(
ref, gko::batch_dim<2>(1, gko::dim<2>{1, 1}))},
beta2{DummyBatchLinOp::create(
ref, gko::batch_dim<2>(2, gko::dim<2>{1, 1}))},
b{DummyBatchLinOp::create(ref,
gko::batch_dim<2>(1, gko::dim<2>{5, 4}))},
b2{DummyBatchLinOp::create(ref,
gko::batch_dim<2>(2, gko::dim<2>{5, 4}))},
x{DummyBatchLinOp::create(ref,
gko::batch_dim<2>(1, gko::dim<2>{3, 4}))},
x2{DummyBatchLinOp::create(ref,
gko::batch_dim<2>(2, gko::dim<2>{3, 4}))}
gko::batch_dim<2>(2, gko::dim<2>{3, 5}))}
{}

std::shared_ptr<const gko::ReferenceExecutor> ref;
std::shared_ptr<const gko::ReferenceExecutor> ref2;
std::unique_ptr<DummyBatchLinOp> op;
std::unique_ptr<DummyBatchLinOp> op2;
std::unique_ptr<DummyBatchLinOp> alpha;
std::unique_ptr<DummyBatchLinOp> alpha2;
std::unique_ptr<DummyBatchLinOp> beta;
std::unique_ptr<DummyBatchLinOp> beta2;
std::unique_ptr<DummyBatchLinOp> b;
std::unique_ptr<DummyBatchLinOp> b2;
std::unique_ptr<DummyBatchLinOp> x;
std::unique_ptr<DummyBatchLinOp> x2;
};


Expand All @@ -134,137 +93,6 @@ TEST_F(EnableBatchLinOp, KnowsItsSizes)
}


TEST_F(EnableBatchLinOp, CallsApplyImpl)
{
op->apply(b, x);

ASSERT_EQ(op->called, 1);
}


TEST_F(EnableBatchLinOp, CallsApplyImplForBatch)
{
op2->apply(b2, x2);

ASSERT_EQ(op2->called, 1);
}


TEST_F(EnableBatchLinOp, CallsExtendedApplyImpl)
{
op->apply(alpha, b, beta, x);

ASSERT_EQ(op->called, 2);
}


TEST_F(EnableBatchLinOp, CallsExtendedApplyImplBatch)
{
op2->apply(alpha2, b2, beta2, x2);

ASSERT_EQ(op2->called, 2);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnWrongBatchSize)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4}));

ASSERT_THROW(op->apply(wrong, x), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnWrongNumBatchItems)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4}));

ASSERT_THROW(op2->apply(wrong, x2), gko::ValueMismatch);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnWrongSolutionRows)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{5, 4}));

ASSERT_THROW(op->apply(b, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnOneBatchItemWrongSolutionRows)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(2, gko::dim<2>{5, 4}));

ASSERT_THROW(op2->apply(b2, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnWrongSolutionColumns)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5}));

ASSERT_THROW(op->apply(b, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ApplyFailsOnOneBatchItemWrongSolutionColumn)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(2, gko::dim<2>{3, 5}));

ASSERT_THROW(op2->apply(b2, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongBatchSize)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4}));

ASSERT_THROW(op->apply(alpha, wrong, beta, x), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongSolutionRows)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{5, 4}));

ASSERT_THROW(op->apply(alpha, b, beta, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongSolutionColumns)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5}));

ASSERT_THROW(op->apply(alpha, b, beta, wrong), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongAlphaDimension)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{2, 5}));

ASSERT_THROW(op->apply(wrong, b, beta, x), gko::DimensionMismatch);
}


TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongBetaDimension)
{
auto wrong =
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{2, 5}));

ASSERT_THROW(op->apply(alpha, b, wrong, x), gko::DimensionMismatch);
}


template <typename T = int>
class DummyBatchLinOpWithFactory
: public gko::batch::EnableBatchLinOp<DummyBatchLinOpWithFactory<T>> {
Expand All @@ -290,17 +118,6 @@ class DummyBatchLinOpWithFactory
{}

std::shared_ptr<const gko::batch::BatchLinOp> op_;

protected:
void apply_impl(const gko::batch::BatchLinOp* b,
gko::batch::BatchLinOp* x) const override
{}

void apply_impl(const gko::batch::BatchLinOp* alpha,
const gko::batch::BatchLinOp* b,
const gko::batch::BatchLinOp* beta,
gko::batch::BatchLinOp* x) const override
{}
};


Expand Down
Loading

0 comments on commit d2a4145

Please sign in to comment.