diff --git a/core/test/base/batch_lin_op.cpp b/core/test/base/batch_lin_op.cpp index b656c2bf3fb..2e0bf0fae0e 100644 --- a/core/test/base/batch_lin_op.cpp +++ b/core/test/base/batch_lin_op.cpp @@ -56,23 +56,6 @@ class DummyBatchLinOp : public gko::batch::EnableBatchLinOp, gko::batch_dim<2> size = gko::batch_dim<2>{}) : gko::batch::EnableBatchLinOp(exec, size) {} - - int called = 0; - -protected: - void apply_impl(const gko::batch::BatchLinOp* b, - gko::batch::BatchLinOp* x) const override - { - this->called = 1; - } - - void apply_impl(const gko::batch::BatchLinOp* alpha, - const gko::batch::BatchLinOp* b, - const gko::batch::BatchLinOp* beta, - gko::batch::BatchLinOp* x) const override - { - this->called = 2; - } }; @@ -84,37 +67,13 @@ class EnableBatchLinOp : public ::testing::Test { op{DummyBatchLinOp::create(ref2, gko::batch_dim<2>(1, gko::dim<2>{3, 5}))}, op2{DummyBatchLinOp::create(ref2, - gko::batch_dim<2>(2, gko::dim<2>{3, 5}))}, - alpha{DummyBatchLinOp::create( - ref, gko::batch_dim<2>(1, gko::dim<2>{1, 1}))}, - alpha2{DummyBatchLinOp::create( - ref, gko::batch_dim<2>(2, gko::dim<2>{1, 1}))}, - beta{DummyBatchLinOp::create( - ref, gko::batch_dim<2>(1, gko::dim<2>{1, 1}))}, - beta2{DummyBatchLinOp::create( - ref, gko::batch_dim<2>(2, gko::dim<2>{1, 1}))}, - b{DummyBatchLinOp::create(ref, - gko::batch_dim<2>(1, gko::dim<2>{5, 4}))}, - b2{DummyBatchLinOp::create(ref, - gko::batch_dim<2>(2, gko::dim<2>{5, 4}))}, - x{DummyBatchLinOp::create(ref, - gko::batch_dim<2>(1, gko::dim<2>{3, 4}))}, - x2{DummyBatchLinOp::create(ref, - gko::batch_dim<2>(2, gko::dim<2>{3, 4}))} + gko::batch_dim<2>(2, gko::dim<2>{3, 5}))} {} std::shared_ptr ref; std::shared_ptr ref2; std::unique_ptr op; std::unique_ptr op2; - std::unique_ptr alpha; - std::unique_ptr alpha2; - std::unique_ptr beta; - std::unique_ptr beta2; - std::unique_ptr b; - std::unique_ptr b2; - std::unique_ptr x; - std::unique_ptr x2; }; @@ -134,137 +93,6 @@ TEST_F(EnableBatchLinOp, KnowsItsSizes) } -TEST_F(EnableBatchLinOp, CallsApplyImpl) -{ - op->apply(b, x); - - ASSERT_EQ(op->called, 1); -} - - -TEST_F(EnableBatchLinOp, CallsApplyImplForBatch) -{ - op2->apply(b2, x2); - - ASSERT_EQ(op2->called, 1); -} - - -TEST_F(EnableBatchLinOp, CallsExtendedApplyImpl) -{ - op->apply(alpha, b, beta, x); - - ASSERT_EQ(op->called, 2); -} - - -TEST_F(EnableBatchLinOp, CallsExtendedApplyImplBatch) -{ - op2->apply(alpha2, b2, beta2, x2); - - ASSERT_EQ(op2->called, 2); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnWrongBatchSize) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4})); - - ASSERT_THROW(op->apply(wrong, x), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnWrongNumBatchItems) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4})); - - ASSERT_THROW(op2->apply(wrong, x2), gko::ValueMismatch); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnWrongSolutionRows) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{5, 4})); - - ASSERT_THROW(op->apply(b, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnOneBatchItemWrongSolutionRows) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(2, gko::dim<2>{5, 4})); - - ASSERT_THROW(op2->apply(b2, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnWrongSolutionColumns) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5})); - - ASSERT_THROW(op->apply(b, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ApplyFailsOnOneBatchItemWrongSolutionColumn) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(2, gko::dim<2>{3, 5})); - - ASSERT_THROW(op2->apply(b2, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongBatchSize) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 4})); - - ASSERT_THROW(op->apply(alpha, wrong, beta, x), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongSolutionRows) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{5, 4})); - - ASSERT_THROW(op->apply(alpha, b, beta, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongSolutionColumns) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5})); - - ASSERT_THROW(op->apply(alpha, b, beta, wrong), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongAlphaDimension) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{2, 5})); - - ASSERT_THROW(op->apply(wrong, b, beta, x), gko::DimensionMismatch); -} - - -TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongBetaDimension) -{ - auto wrong = - DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{2, 5})); - - ASSERT_THROW(op->apply(alpha, b, wrong, x), gko::DimensionMismatch); -} - - template class DummyBatchLinOpWithFactory : public gko::batch::EnableBatchLinOp> { @@ -290,17 +118,6 @@ class DummyBatchLinOpWithFactory {} std::shared_ptr op_; - -protected: - void apply_impl(const gko::batch::BatchLinOp* b, - gko::batch::BatchLinOp* x) const override - {} - - void apply_impl(const gko::batch::BatchLinOp* alpha, - const gko::batch::BatchLinOp* b, - const gko::batch::BatchLinOp* beta, - gko::batch::BatchLinOp* x) const override - {} }; diff --git a/include/ginkgo/core/base/batch_lin_op.hpp b/include/ginkgo/core/base/batch_lin_op.hpp index a04ae3e79ce..ac632c715e8 100644 --- a/include/ginkgo/core/base/batch_lin_op.hpp +++ b/include/ginkgo/core/base/batch_lin_op.hpp @@ -91,99 +91,6 @@ namespace batch { */ class BatchLinOp : public EnableAbstractPolymorphicObject { public: - /** - * Applies a batch linear operator to a batch vector (or a sequence of batch - * of vectors). - * - * Performs the operation x = op(b), where op is this batch linear operator. - * - * @param b the input batch vector(s) on which the batch operator is - * applied - * @param x the output batch vector(s) where the result is stored - * - * @return this - */ - BatchLinOp* apply(ptr_param b, ptr_param x) - { - this->template log( - this, b.get(), x.get()); - this->validate_application_parameters(b.get(), x.get()); - auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, b).get(), - make_temporary_clone(exec, x).get()); - this->template log( - this, b.get(), x.get()); - return this; - } - - /** - * @copydoc apply(const BatchLinOp *, BatchLinOp *) - */ - const BatchLinOp* apply(ptr_param b, - ptr_param x) const - { - this->template log( - this, b.get(), x.get()); - this->validate_application_parameters(b.get(), x.get()); - auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, b).get(), - make_temporary_clone(exec, x).get()); - this->template log( - this, b.get(), x.get()); - return this; - } - - /** - * Performs the operation x = alpha * op(b) + beta * x. - * - * @param alpha scaling of the result of op(b) - * @param b vector(s) on which the operator is applied - * @param beta scaling of the input x - * @param x output vector(s) - * - * @return this - */ - BatchLinOp* apply(ptr_param alpha, - ptr_param b, - ptr_param beta, ptr_param x) - { - this->template log( - this, alpha.get(), b.get(), beta.get(), x.get()); - this->validate_application_parameters(alpha.get(), b.get(), beta.get(), - x.get()); - auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, alpha).get(), - make_temporary_clone(exec, b).get(), - make_temporary_clone(exec, beta).get(), - make_temporary_clone(exec, x).get()); - this->template log( - this, alpha.get(), b.get(), beta.get(), x.get()); - return this; - } - - /** - * @copydoc apply(const BatchLinOp *, const BatchLinOp *, const BatchLinOp - * *, BatchLinOp *) - */ - const BatchLinOp* apply(ptr_param alpha, - ptr_param b, - ptr_param beta, - ptr_param x) const - { - this->template log( - this, alpha.get(), b.get(), beta.get(), x.get()); - this->validate_application_parameters(alpha.get(), b.get(), beta.get(), - x.get()); - auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, alpha).get(), - make_temporary_clone(exec, b).get(), - make_temporary_clone(exec, beta).get(), - make_temporary_clone(exec, x).get()); - this->template log( - this, alpha.get(), b.get(), beta.get(), x.get()); - return this; - } - /** * Returns the number of batches in the batch operator. * @@ -236,66 +143,6 @@ class BatchLinOp : public EnableAbstractPolymorphicObject { : EnableAbstractPolymorphicObject(exec), size_{batch_size} {} - /** - * Implementers of BatchLinOp should override this function instead - * of apply(const BatchLinOp *, BatchLinOp *). - * - * Performs the operation x = op(b), where op is this linear operator. - * - * @param b the input batch vector(s) on which the operator is applied - * @param x the output batch vector(s) where the result is stored - */ - virtual void apply_impl(const BatchLinOp* b, BatchLinOp* x) const = 0; - - /** - * Implementers of BatchLinOp should override this function instead - * of apply(const BatchLinOp *, const BatchLinOp *, const BatchLinOp *, - * BatchLinOp *). - * - * @param alpha scaling of the result of op(b) - * @param b vector(s) on which the operator is applied - * @param beta scaling of the input x - * @param x output vector(s) - */ - virtual void apply_impl(const BatchLinOp* alpha, const BatchLinOp* b, - const BatchLinOp* beta, BatchLinOp* x) const = 0; - - /** - * Throws a DimensionMismatch exception if the parameters to `apply` are of - * the wrong size. - * - * @param b batch vector(s) on which the operator is applied - * @param x output batch vector(s) - */ - void validate_application_parameters(const BatchLinOp* b, - const BatchLinOp* x) const - { - GKO_ASSERT_BATCH_CONFORMANT(this, b); - GKO_ASSERT_BATCH_EQUAL_ROWS(this, x); - GKO_ASSERT_BATCH_EQUAL_COLS(b, x); - } - - /** - * Throws a DimensionMismatch exception if the parameters to `apply` are of - * the wrong size. - * - * @param alpha scaling of the result of op(b) - * @param b batch vector(s) on which the operator is applied - * @param beta scaling of the input x - * @param x output batch vector(s) - */ - void validate_application_parameters(const BatchLinOp* alpha, - const BatchLinOp* b, - const BatchLinOp* beta, - const BatchLinOp* x) const - { - this->validate_application_parameters(b, x); - GKO_ASSERT_BATCH_EQUAL_ROWS( - alpha, batch_dim<2>(b->get_num_batch_items(), dim<2>(1, 1))); - GKO_ASSERT_BATCH_EQUAL_ROWS( - beta, batch_dim<2>(b->get_num_batch_items(), dim<2>(1, 1))); - } - private: batch_dim<2> size_{}; }; @@ -395,38 +242,6 @@ class EnableBatchLinOp using EnablePolymorphicObject::EnablePolymorphicObject; - const ConcreteBatchLinOp* apply(ptr_param b, - ptr_param x) const - { - PolymorphicBase::apply(b, x); - return self(); - } - - ConcreteBatchLinOp* apply(ptr_param b, - ptr_param x) - { - PolymorphicBase::apply(b, x); - return self(); - } - - const ConcreteBatchLinOp* apply(ptr_param alpha, - ptr_param b, - ptr_param beta, - ptr_param x) const - { - PolymorphicBase::apply(alpha, b, beta, x); - return self(); - } - - ConcreteBatchLinOp* apply(ptr_param alpha, - ptr_param b, - ptr_param beta, - ptr_param x) - { - PolymorphicBase::apply(alpha, b, beta, x); - return self(); - } - protected: GKO_ENABLE_SELF(ConcreteBatchLinOp); };