diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 6c4d92010a..c0a6537ee9 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -11,6 +11,7 @@ set(libmodcc_sources functioninliner.cpp lexer.cpp kineticrewriter.cpp + linearrewriter.cpp module.cpp parser.cpp solvers.cpp diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 441ea63eb3..ef404d43fd 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -359,6 +359,25 @@ void StoichExpression::semantic(scope_ptr scp) { } } +/******************************************************************************* + LinearExpression +*******************************************************************************/ + +expression_ptr LinearExpression::clone() const { + return make_expression( + location_, lhs()->clone(), rhs()->clone()); +} + +void LinearExpression::semantic(scope_ptr scp) { + scope_ = scp; + lhs_->semantic(scp); + rhs_->semantic(scp); + + if(rhs_->is_procedure_call()) { + error("procedure calls can't be made in an expression"); + } +} + /******************************************************************************* ConserveExpression *******************************************************************************/ @@ -984,6 +1003,9 @@ void ConserveExpression::accept(Visitor *v) { void ReactionExpression::accept(Visitor *v) { v->visit(this); } +void LinearExpression::accept(Visitor *v) { + v->visit(this); +} void StoichExpression::accept(Visitor *v) { v->visit(this); } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 0613ddcfec..1860c7eea7 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -33,6 +33,7 @@ class BinaryExpression; class UnaryExpression; class AssignmentExpression; class ConserveExpression; +class LinearExpression; class ReactionExpression; class StoichExpression; class StoichTermExpression; @@ -79,7 +80,8 @@ enum class procedureKind { net_receive, ///< NET_RECEIVE breakpoint, ///< BREAKPOINT kinetic, ///< KINETIC - derivative ///< DERIVATIVE + derivative, ///< DERIVATIVE + linear, ///< LINEAR }; std::string to_string(procedureKind k); @@ -168,6 +170,7 @@ class Expression { virtual UnaryExpression* is_unary() {return nullptr;} virtual AssignmentExpression* is_assignment() {return nullptr;} virtual ConserveExpression* is_conserve() {return nullptr;} + virtual LinearExpression* is_linear() {return nullptr;} virtual ReactionExpression* is_reaction() {return nullptr;} virtual StoichExpression* is_stoich() {return nullptr;} virtual StoichTermExpression* is_stoich_term() {return nullptr;} @@ -1277,6 +1280,20 @@ class ConserveExpression : public BinaryExpression { void accept(Visitor *v) override; }; +class LinearExpression : public BinaryExpression { +public: + LinearExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::eq, std::move(lhs), std::move(rhs)) + {} + + LinearExpression* is_linear() override {return this;} + expression_ptr clone() const override; + + void semantic(scope_ptr scp) override; + + void accept(Visitor *v) override; +}; + class AddBinaryExpression : public BinaryExpression { public: AddBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) diff --git a/modcc/lexer.cpp b/modcc/lexer.cpp index a2bb90af8d..1eb684ff9f 100644 --- a/modcc/lexer.cpp +++ b/modcc/lexer.cpp @@ -234,6 +234,30 @@ Token Lexer::peek() { return t; } +bool Lexer::search_to_eol(tok const& t) { + // save the current position + const char *oldpos = current_; + const char *oldlin = line_; + Location oldloc = location_; + + Token p = token_; + bool ret = false; + while (line_ == oldlin && p.type != tok::eof) { + if (p.type == t) { + ret = true; + break; + } + p = parse(); + } + + // reset position + current_ = oldpos; + location_ = oldloc; + line_ = oldlin; + + return ret; +} + // scan floating point number from stream Token Lexer::number() { std::string str; diff --git a/modcc/lexer.hpp b/modcc/lexer.hpp index 142a5d3b1f..f897db5856 100644 --- a/modcc/lexer.hpp +++ b/modcc/lexer.hpp @@ -72,6 +72,9 @@ class Lexer { // return the next token in the stream without advancing the current position Token peek(); + // Look for `t` until new line or eof without advancing the current position, return true if found + bool search_to_eol(tok const& t); + // scan a number from the stream Token number(); diff --git a/modcc/linearrewriter.cpp b/modcc/linearrewriter.cpp new file mode 100644 index 0000000000..3705a2ed74 --- /dev/null +++ b/modcc/linearrewriter.cpp @@ -0,0 +1,88 @@ +#include +#include +#include +#include + +#include "astmanip.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" + +class LinearRewriter : public BlockRewriterBase { +public: + using BlockRewriterBase::visit; + + LinearRewriter(std::vector st_vars): state_vars(st_vars) {} + LinearRewriter(scope_ptr enclosing_scope): BlockRewriterBase(enclosing_scope) {} + + virtual void visit(LinearExpression *e) override; + +protected: + virtual void reset() override { + BlockRewriterBase::reset(); + } + +private: + std::vector state_vars; +}; + +expression_ptr linear_rewrite(BlockExpression* block, std::vector state_vars) { + LinearRewriter visitor(state_vars); + block->accept(&visitor); + return visitor.as_block(false); +} + +// LinearRewriter implementation follows. + +// Factorize the linear expression in terms of the state variables and place +// the resulting sum of products on the lhs. Place everything else on the rhs +void LinearRewriter::visit(LinearExpression* e) { + Location loc = e->location(); + scope_ptr scope = e->scope(); + + expression_ptr lhs; + for (const auto& state : state_vars) { + // To factorize w.r.t state, differentiate the lhs and rhs + auto ident = make_expression(loc, state); + auto coeff = constant_simplify(make_expression(loc, + symbolic_pdiff(e->lhs(), state), + symbolic_pdiff(e->rhs(), state))); + + if (expr_value(coeff) != 0) { + auto local_coeff = make_unique_local_assign(scope, coeff, "l_"); + statements_.push_back(std::move(local_coeff.local_decl)); + statements_.push_back(std::move(local_coeff.assignment)); + + auto pair = make_expression(loc, std::move(local_coeff.id), std::move(ident)); + + // Construct the lhs of the new linear expression + if (!lhs) { + lhs = std::move(pair); + } else { + lhs = make_expression(loc, std::move(lhs), std::move(pair)); + } + } + } + + // To find the rhs of the new linear expression, simplify the old + // linear expression with state variables set to zero + auto rhs_0 = e->lhs()->clone(); + auto rhs_1 = e->rhs()->clone(); + + for (auto state: state_vars) { + auto zero_expr = make_expression(loc, 0.0); + rhs_0 = substitute(rhs_0, state, zero_expr); + rhs_1 = substitute(rhs_1, state, zero_expr); + } + rhs_0 = constant_simplify(rhs_0); + rhs_1 = constant_simplify(rhs_1); + + auto rhs = constant_simplify(make_expression(loc, std::move(rhs_1), std::move(rhs_0))); + + auto local_rhs = make_unique_local_assign(scope, rhs, "l_"); + statements_.push_back(std::move(local_rhs.local_decl)); + statements_.push_back(std::move(local_rhs.assignment)); + + rhs = std::move(local_rhs.id); + + statements_.push_back(make_expression(loc, std::move(lhs), std::move(rhs))); +} diff --git a/modcc/linearrewriter.hpp b/modcc/linearrewriter.hpp new file mode 100644 index 0000000000..32b3386334 --- /dev/null +++ b/modcc/linearrewriter.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include "expression.hpp" + +// Translate a supplied LINEAR block. +expression_ptr linear_rewrite(BlockExpression*, std::vector); diff --git a/modcc/module.cpp b/modcc/module.cpp index 86a3ebec87..2df8ec16b7 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -11,6 +11,7 @@ #include "functionexpander.hpp" #include "functioninliner.hpp" #include "kineticrewriter.hpp" +#include "linearrewriter.hpp" #include "module.hpp" #include "parser.hpp" #include "solvers.hpp" @@ -279,7 +280,45 @@ bool Module::semantic() { auto& init_body = api_init->body()->statements(); for(auto& e : *proc_init->body()) { - init_body.emplace_back(e->clone()); + auto solve_expression = e->is_solve_statement(); + if (solve_expression) { + // Grab SOLVE statements, put them in `body` after translation. + std::set solved_ids; + std::unique_ptr solver = std::make_unique(); + + // The solve expression inside an initial block can only refer to a linear block + auto solve_proc = solve_expression->procedure(); + + if (solve_proc->kind() == procedureKind::linear) { + solver = std::make_unique(state_vars); + linear_rewrite(solve_proc->body(), state_vars)->accept(solver.get()); + } else { + error("A SOLVE expression in an INITIAL block can only be used to solve a LINEAR block, which" + + solve_expression->name() + "is not.", solve_expression->location()); + return false; + } + + if (auto solve_block = solver->as_block(false)) { + // Check that we didn't solve an already solved variable. + for (const auto &id: solver->solved_identifiers()) { + if (solved_ids.count(id) > 0) { + error("Variable " + id + " solved twice!", solve_expression->location()); + return false; + } + solved_ids.insert(id); + } + // Copy body into nrn_init. + for (auto &stmt: solve_block->is_block()->statements()) { + init_body.emplace_back(stmt->clone()); + } + } else { + // Something went wrong: copy errors across. + append_errors(solver->errors()); + return false; + } + } else { + init_body.emplace_back(e->clone()); + } } api_init->semantic(symbols_); @@ -337,6 +376,10 @@ bool Module::semantic() { if (deriv->kind()==procedureKind::kinetic) { kinetic_rewrite(deriv->body())->accept(solver.get()); } + else if (deriv->kind()==procedureKind::linear) { + solver = std::make_unique(state_vars); + linear_rewrite(deriv->body(), state_vars)->accept(solver.get()); + } else { deriv->body()->accept(solver.get()); for (auto& s: deriv->body()->statements()) { diff --git a/modcc/msparse.hpp b/modcc/msparse.hpp index b6e2292cf9..34e6866826 100644 --- a/modcc/msparse.hpp +++ b/modcc/msparse.hpp @@ -185,6 +185,12 @@ class matrix { bool empty() const { return size()==0; } bool augmented() const { return aug!=npos; } + void clear() { + rows.clear(); + cols = 0; + aug = row_npos; + } + // Add a column on the right as part of the augmented submatrix. // The new entries are provided by a (full, dense representation) // sequence of values. diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 0e9e064bd5..0202d0b1bb 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -122,6 +122,7 @@ bool Parser::parse() { case tok::breakpoint : case tok::initial : case tok::kinetic : + case tok::linear : case tok::derivative : case tok::procedure : { @@ -905,6 +906,12 @@ symbol_ptr Parser::parse_procedure() { if( !expect( tok::identifier ) ) return nullptr; p = parse_prototype(); break; + case tok::linear: + kind = procedureKind::linear; + get_token(); // consume keyword token + if( !expect( tok::identifier ) ) return nullptr; + p = parse_prototype(); + break; case tok::procedure: kind = procedureKind::normal; get_token(); // consume keyword token @@ -994,7 +1001,7 @@ expression_ptr Parser::parse_statement() { case tok::conserve : return parse_conserve_expression(); case tok::tilde : - return parse_reaction_expression(); + return parse_tilde_expression(); case tok::initial : // only used for INITIAL block in NET_RECEIVE return parse_initial(); @@ -1183,76 +1190,94 @@ expression_ptr Parser::parse_stoich_expression() { return make_expression(here, std::move(terms)); } -expression_ptr Parser::parse_reaction_expression() { +expression_ptr Parser::parse_tilde_expression() { auto here = location_; if(token_.type!=tok::tilde) { error(pprintf("expected '%', found '%'", yellow("~"), yellow(token_.spelling))); return nullptr; } - get_token(); // consume tilde - expression_ptr lhs = parse_stoich_expression(); - if (!lhs) return nullptr; - // reaction halves must comprise non-negative terms - for (const auto& term: lhs->is_stoich()->terms()) { - // should always be true - if (auto sterm = term->is_stoich_term()) { - if (sterm->negative()) { - error(pprintf("expected only non-negative terms in reaction lhs, found '%'", - yellow(term->to_string()))); - return nullptr; + if (search_to_eol(tok::arrow)) { + expression_ptr lhs = parse_stoich_expression(); + if (!lhs) return nullptr; + + // reaction halves must comprise non-negative terms + for (const auto& term: lhs->is_stoich()->terms()) { + // should always be true + if (auto sterm = term->is_stoich_term()) { + if (sterm->negative()) { + error(pprintf("expected only non-negative terms in reaction lhs, found '%'", + yellow(term->to_string()))); + return nullptr; + } } } - } - if(token_.type != tok::arrow) { - error(pprintf("expected '%', found '%'", yellow("<->"), yellow(token_.spelling))); - return nullptr; - } + if(token_.type != tok::arrow) { + error(pprintf("expected '%', found '%'", yellow("<->"), yellow(token_.spelling))); + return nullptr; + } - get_token(); // consume arrow - expression_ptr rhs = parse_stoich_expression(); - if (!rhs) return nullptr; + get_token(); // consume arrow + expression_ptr rhs = parse_stoich_expression(); + if (!rhs) return nullptr; - for (const auto& term: rhs->is_stoich()->terms()) { - // should always be true - if (auto sterm = term->is_stoich_term()) { - if (sterm->negative()) { - error(pprintf("expected only non-negative terms in reaction rhs, found '%'", - yellow(term->to_string()))); - return nullptr; + for (const auto& term: rhs->is_stoich()->terms()) { + // should always be true + if (auto sterm = term->is_stoich_term()) { + if (sterm->negative()) { + error(pprintf("expected only non-negative terms in reaction rhs, found '%'", + yellow(term->to_string()))); + return nullptr; + } } } - } - if(token_.type != tok::lparen) { - error(pprintf("expected '%', found '%'", yellow("("), yellow(token_.spelling))); - return nullptr; - } + if (token_.type != tok::lparen) { + error(pprintf("expected '%', found '%'", yellow("("), yellow(token_.spelling))); + return nullptr; + } - get_token(); // consume lparen - expression_ptr fwd = parse_expression(); - if (!fwd) return nullptr; + get_token(); // consume lparen + expression_ptr fwd = parse_expression(); + if (!fwd) return nullptr; - if(token_.type != tok::comma) { - error(pprintf("expected '%', found '%'", yellow(","), yellow(token_.spelling))); - return nullptr; + if (token_.type != tok::comma) { + error(pprintf("expected '%', found '%'", yellow(","), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume comma + expression_ptr rev = parse_expression(); + if (!rev) return nullptr; + + if (token_.type != tok::rparen) { + error(pprintf("expected '%', found '%'", yellow(")"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume rparen + return make_expression(here, std::move(lhs), std::move(rhs), + std::move(fwd), std::move(rev)); } + else if (search_to_eol(tok::eq)) { + auto lhs_bin = parse_expression(); - get_token(); // consume comma - expression_ptr rev = parse_expression(); - if (!rev) return nullptr; + if(token_.type!=tok::eq) { + error(pprintf("expected '%', found '%'", yellow("="), yellow(token_.spelling))); + return nullptr; + } - if(token_.type != tok::rparen) { - error(pprintf("expected '%', found '%'", yellow(")"), yellow(token_.spelling))); + get_token(); // consume = + auto rhs = parse_expression(); + return make_expression(here, std::move(lhs_bin), std::move(rhs)); + } + else { + error(pprintf("expected stoichiometric or linear expression, found neither")); return nullptr; } - - get_token(); // consume rparen - return make_expression(here, std::move(lhs), std::move(rhs), - std::move(fwd), std::move(rev)); } expression_ptr Parser::parse_conserve_expression() { @@ -1286,8 +1311,7 @@ expression_ptr Parser::parse_expression(int prec) { // Combine all sub-expressions with precedence greater than prec. for (;;) { if(token_.type==tok::eq) { - error("assignment '"+yellow("=")+"' not allowed in sub-expression"); - return nullptr; + return lhs; } auto op = token_; diff --git a/modcc/parser.hpp b/modcc/parser.hpp index 3c2a297a2c..ad841b57f9 100644 --- a/modcc/parser.hpp +++ b/modcc/parser.hpp @@ -28,7 +28,7 @@ class Parser : public Lexer { expression_ptr parse_line_expression(); expression_ptr parse_stoich_expression(); expression_ptr parse_stoich_term(); - expression_ptr parse_reaction_expression(); + expression_ptr parse_tilde_expression(); expression_ptr parse_conserve_expression(); expression_ptr parse_binop(expression_ptr&&, Token); expression_ptr parse_unaryop(); diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index 96872a51e4..6f77091263 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -368,14 +368,108 @@ void SparseSolverVisitor::finalize() { // State variable updates given by rhs/diagonal for reduced matrix. Location loc; for (unsigned i = 0; i(loc, - make_expression(loc, dvars_[i]), + make_expression(loc, dvars_[lhs_col]), make_expression(loc, - make_expression(loc, symge::name(A_[i][rhs])), - make_expression(loc, symge::name(A_[i][i])))); + make_expression(loc, symge::name(A_[i][rhs_col])), + make_expression(loc, symge::name(A_[i][lhs_col])))); + + statements_.push_back(std::move(expr)); + } + + BlockRewriterBase::finalize(); +} + +void LinearSolverVisitor::visit(BlockExpression* e) { + BlockRewriterBase::visit(e); +} + +void LinearSolverVisitor::visit(AssignmentExpression *e) { + statements_.push_back(e->clone()); + return; +} + +void LinearSolverVisitor::visit(LinearExpression *e) { + auto loc = e->location(); + scope_ptr scope = e->scope(); + + if (A_.empty()) { + unsigned n = dvars_.size(); + A_ = symge::sym_matrix(n, n); + } + + linear_test_result r = linear_test(e->lhs(), dvars_); + if (!r.is_homogeneous) { + error({"System not homogeneous linear for sparse", loc}); + return; + } + + for (unsigned j = 0; jclone(); + } + + if (!expr) continue; + + auto a_ = expr->is_identifier()->spelling(); + + A_[deq_index_].push_back({j, symtbl_.define(a_)}); + } + rhs_.push_back(symtbl_.define(e->rhs()->is_identifier()->spelling())); + ++deq_index_; +} +void LinearSolverVisitor::finalize() { + A_.augment(rhs_); + + symge::gj_reduce(A_, symtbl_); + + // Create and assign intermediate variables. + for (unsigned i = 0; iis_identifier()->spelling(); + symtbl_.name(s, t_); + + statements_.push_back(std::move(local_t_term.local_decl)); + statements_.push_back(std::move(local_t_term.assignment)); + } + + // State variable updates given by rhs/diagonal for reduced matrix. + Location loc; + for (unsigned i = 0; i < A_.nrow(); ++i) { + const symge::sym_row& row = A_[i]; + unsigned rhs = A_.augcol(); + unsigned lhs; + for (unsigned r = 0; r < A_.nrow(); r++) { + if (row[r]) { + lhs = r; + break; + } + } + + auto expr = + make_expression(loc, + make_expression(loc, dvars_[lhs]), + make_expression(loc, + make_expression(loc, symge::name(A_[i][rhs])), + make_expression(loc, symge::name(A_[i][lhs])))); statements_.push_back(std::move(expr)); } diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp index e31084dfc0..1198d1b104 100644 --- a/modcc/solvers.hpp +++ b/modcc/solvers.hpp @@ -101,10 +101,52 @@ class SparseSolverVisitor : public SolverVisitorBase { virtual void reset() override { deq_index_ = 0; local_expr_.clear(); + A_.clear(); symtbl_.clear(); + conserve_ = false; conserve_rhs_.clear(); conserve_idx_.clear(); - conserve_ = false; + SolverVisitorBase::reset(); + } +}; + +class LinearSolverVisitor : public SolverVisitorBase { +protected: + // 'Current' differential equation is for variable with this + // index in `dvars`. + unsigned deq_index_ = 0; + + // Expanded local assignments that need to be substituted in for derivative + // calculations. + substitute_map local_expr_; + + // Symbolic matrix for backwards Euler step. + symge::sym_matrix A_; + + // RHS + std::vector rhs_; + + // 'Symbol table' for symbolic manipulation. + symge::symbol_table symtbl_; + +public: + using SolverVisitorBase::visit; + + LinearSolverVisitor(std::vector vars) { + dvars_ = vars; + } + LinearSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(LinearExpression *e) override; + virtual void visit(AssignmentExpression *e) override; + virtual void finalize() override; + virtual void reset() override { + deq_index_ = 0; + local_expr_.clear(); + A_.clear(); + rhs_.clear(); + symtbl_.clear(); SolverVisitorBase::reset(); } }; diff --git a/modcc/symge.cpp b/modcc/symge.cpp index 4f09a0226a..51a9008cd9 100644 --- a/modcc/symge.cpp +++ b/modcc/symge.cpp @@ -1,11 +1,18 @@ #include #include #include +#include #include "symge.hpp" namespace symge { +struct pivot { + unsigned row; + unsigned col; +}; + + // Returns q[c]*p - p[c]*q; new symbols required due to fill-in are provided by the // `define_sym` functor, which takes a `symbol_term_diff` and returns a `symbol`. @@ -45,7 +52,7 @@ sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym def // Estimate cost of a choice of pivot for G–J reduction below. Uses a simple greedy // estimate based on immediate fill cost. -double estimate_cost(const sym_matrix& A, unsigned p) { +double estimate_cost(const sym_matrix& A, pivot p) { unsigned nfill = 0; auto count_fill = [&nfill](symbol_term_diff t) { @@ -56,8 +63,8 @@ double estimate_cost(const sym_matrix& A, unsigned p) { }; for (unsigned i = 0; i pivots; - for (unsigned r = 0; r& remaining_rows) { + std::vector pivots; + for (auto r: remaining_rows) { + pivot p; + p.row = r; + const sym_row &row = A[r]; + for (unsigned c = 0; c < A.nrow(); ++c) { + if (row[c]) { + p.col = c; + break; + } + } + pivots.push_back(std::move(p)); + } + return pivots; + }; - std::vector cost(pivots.size()); + std::vector remaining_rows(A.nrow()); + std::iota(remaining_rows.begin(), remaining_rows.end(), 0); + + std::vector cost(A.nrow()); + + while (true) { + auto pivots = get_pivots(remaining_rows); - while (!pivots.empty()) { for (unsigned i = 0; icost[r2]; }); + [&](pivot r1, pivot r2) { return cost[r1.row]>cost[r2.row]; }); - unsigned pivrow = pivots.back(); - pivots.erase(std::prev(pivots.end())); - - unsigned pivcol = pivrow; + pivot p = pivots.back(); + remaining_rows.erase(std::lower_bound(remaining_rows.begin(), remaining_rows.end(), p.row)); for (unsigned i = 0; i s; - EXPECT_TRUE(check_parse(s, &Parser::parse_reaction_expression, text)); + EXPECT_TRUE(check_parse(s, &Parser::parse_tilde_expression, text)); } const char* bad_expr[] = { @@ -483,7 +483,7 @@ TEST(Parser, parse_reaction_expression) { }; for (auto& text: bad_expr) { - EXPECT_TRUE(check_parse_fail(&Parser::parse_reaction_expression, text)); + EXPECT_TRUE(check_parse_fail(&Parser::parse_tilde_expression, text)); } } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 1e042fbb90..751647b42a 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -2,6 +2,9 @@ set(test_mechanisms celsius_test + test_linear_state + test_linear_init + test_linear_init_shuffle test0_kin_diff test0_kin_conserve test1_kin_diff @@ -83,7 +86,7 @@ set(unit_sources test_fvm_lowered.cpp test_glob_basic.cpp test_mc_cell_group.cpp - test_kinetic.cpp + test_kinetic_linear.cpp test_lexcmp.cpp test_lif_cell_group.cpp test_maputil.cpp diff --git a/test/unit/mod/test_linear_init.mod b/test/unit/mod/test_linear_init.mod new file mode 100644 index 0000000000..089f0453d0 --- /dev/null +++ b/test/unit/mod/test_linear_init.mod @@ -0,0 +1,33 @@ +NEURON { + SUFFIX test_linear_init + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + s = a1 +} + +INITIAL { + SOLVE sinit +} + +LINEAR sinit { + ~ (a4 - a3)*d - a2*h = 0 + ~ (a0 + a1)*s - (-a1 + a0)*d = 0 + ~ s + d + h = 1 +} diff --git a/test/unit/mod/test_linear_init_shuffle.mod b/test/unit/mod/test_linear_init_shuffle.mod new file mode 100644 index 0000000000..07d4ce954b --- /dev/null +++ b/test/unit/mod/test_linear_init_shuffle.mod @@ -0,0 +1,33 @@ +NEURON { + SUFFIX test_linear_init_shuffle + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + s = a1 +} + +INITIAL { + SOLVE sinit +} + +LINEAR sinit { + ~ a4*d - a3*d - a2*h = 0 + ~ a0*s - a0*d = - a1*s - a1*d + ~ s + d + h = 1 +} diff --git a/test/unit/mod/test_linear_state.mod b/test/unit/mod/test_linear_state.mod new file mode 100644 index 0000000000..1b0f823509 --- /dev/null +++ b/test/unit/mod/test_linear_state.mod @@ -0,0 +1,29 @@ +NEURON { + SUFFIX test_linear_state + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + SOLVE sinit +} + +LINEAR sinit { + ~ (a4 - a3)*d - a2*h = 0 + ~ (a0 + a1)*s - (-a1 + a0)*d = 0 + ~ s + d + h = 1 +} diff --git a/test/unit/test_kinetic.cpp b/test/unit/test_kinetic.cpp deleted file mode 100644 index 71f1a636f0..0000000000 --- a/test/unit/test_kinetic.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include - -#include -#include - -#include "backends/multicore/fvm.hpp" - -#ifdef ARB_GPU_ENABLED -#include "backends/gpu/fvm.hpp" -#endif - -#include "common.hpp" -#include "mech_private_field_access.hpp" -#include "fvm_lowered_cell.hpp" -#include "fvm_lowered_cell_impl.hpp" -#include "sampler_map.hpp" -#include "simple_recipes.hpp" -#include "unit_test_catalogue.hpp" - -using namespace arb; - -using backend = arb::multicore::backend; -using fvm_cell = arb::fvm_lowered_cell_impl; - -using shared_state = backend::shared_state; -ACCESS_BIND(std::unique_ptr fvm_cell::*, private_state_ptr, &fvm_cell::state_) - -template -void run_kinetic_test(std::string mech_name, - std::vector variables, - std::vector t0_values, - std::vector t1_values) { - - auto cat = make_unit_test_catalogue(); - - fvm_size_type ncell = 1; - fvm_size_type ncv = 1; - std::vector cv_to_intdom(ncv, 0); - - std::vector gj = {}; - auto instance = cat.instance(mech_name); - auto& kinetic_test = instance.mech; - - std::vector temp(ncv, 300.); - std::vector vinit(ncv, -65); - - auto shared_state = std::make_unique( - ncell, cv_to_intdom, gj, vinit, temp, kinetic_test->data_alignment()); - - mechanism_layout layout; - mechanism_overrides overrides; - - layout.weight.assign(ncv, 1.); - for (fvm_size_type i = 0; iinstantiate(0, *shared_state, overrides, layout); - shared_state->reset(); - - kinetic_test->initialize(); - - for (unsigned i = 0; i < variables.size(); i++) { - for (unsigned j = 0; j < ncv; j++) { - EXPECT_NEAR(t0_values[i], mechanism_field(kinetic_test.get(), variables[i]).at(j), 1e-6); - } - } - - shared_state->update_time_to(0.5, 0.5); - shared_state->set_dt(); - - kinetic_test->nrn_state(); - - for (unsigned i = 0; i < variables.size(); i++) { - for (unsigned j = 0; j < ncv; j++) { - EXPECT_NEAR(t1_values[i], mechanism_field(kinetic_test.get(), variables[i]).at(j), 1e-6); - } - } -} - -TEST(mech_kinetic, kinetic_1_conserve) { - std::vector variables = {"s", "h", "d"}; - std::vector t0_values = {0.5, 0.2, 0.3}; - std::vector t1_values = {0.380338, 0.446414, 0.173247}; - - run_kinetic_test("test0_kin_diff", variables, t0_values, t1_values); - run_kinetic_test("test0_kin_conserve", variables, t0_values, t1_values); -} - -TEST(mech_kinetic, kinetic_2_conserve) { - std::vector variables = {"a", "b", "x", "y"}; - std::vector t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; - - run_kinetic_test("test1_kin_diff", variables, t0_values, t1_values); - run_kinetic_test("test1_kin_conserve", variables, t0_values, t1_values); -} - -#ifdef ARB_GPU_ENABLED -TEST(mech_kinetic_gpu, kinetic_1_conserve) { - std::vector variables = {"s", "h", "d"}; - std::vector t0_values = {0.5, 0.2, 0.3}; - std::vector t1_values = {0.380338, 0.446414, 0.173247}; - - run_kinetic_test("test0_kin_diff", variables, t0_values, t1_values); - run_kinetic_test("test0_kin_conserve", variables, t0_values, t1_values); -} - -TEST(mech_kinetic_gpu, kinetic_2_conserve) { - std::vector variables = {"a", "b", "x", "y"}; - std::vector t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; - - run_kinetic_test("test1_kin_diff", variables, t0_values, t1_values); - run_kinetic_test("test1_kin_conserve", variables, t0_values, t1_values); -} -#endif diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp new file mode 100644 index 0000000000..16dddc4eda --- /dev/null +++ b/test/unit/test_kinetic_linear.cpp @@ -0,0 +1,148 @@ +#include + +#include +#include + +#include "backends/multicore/fvm.hpp" + +#ifdef ARB_GPU_ENABLED +#include "backends/gpu/fvm.hpp" +#endif + +#include "common.hpp" +#include "mech_private_field_access.hpp" +#include "fvm_lowered_cell.hpp" +#include "fvm_lowered_cell_impl.hpp" +#include "sampler_map.hpp" +#include "simple_recipes.hpp" +#include "unit_test_catalogue.hpp" + +using namespace arb; + +using backend = arb::multicore::backend; +using fvm_cell = arb::fvm_lowered_cell_impl; + +using shared_state = backend::shared_state; +ACCESS_BIND(std::unique_ptr fvm_cell::*, private_state_ptr, &fvm_cell::state_) + +template +void run_test(std::string mech_name, + std::vector state_variables, + std::unordered_map assigned_variables, + std::vector t0_values, + std::vector t1_values) { + + auto cat = make_unit_test_catalogue(); + + fvm_size_type ncell = 1; + fvm_size_type ncv = 1; + std::vector cv_to_intdom(ncv, 0); + + std::vector gj = {}; + auto instance = cat.instance(mech_name); + auto& test = instance.mech; + + std::vector temp(ncv, 300.); + std::vector vinit(ncv, -65); + + auto shared_state = std::make_unique( + ncell, cv_to_intdom, gj, vinit, temp, test->data_alignment()); + + mechanism_layout layout; + mechanism_overrides overrides; + + layout.weight.assign(ncv, 1.); + for (fvm_size_type i = 0; iinstantiate(0, *shared_state, overrides, layout); + + for (auto a: assigned_variables) { + test->set_parameter(a.first, std::vector(ncv,a.second)); + } + + shared_state->reset(); + + test->initialize(); + + if (!t0_values.empty()) { + for (unsigned i = 0; i < state_variables.size(); i++) { + for (unsigned j = 0; j < ncv; j++) { + EXPECT_NEAR(t0_values[i], mechanism_field(test.get(), state_variables[i]).at(j), 1e-6); + } + } + } + + shared_state->update_time_to(0.5, 0.5); + shared_state->set_dt(); + + test->nrn_state(); + + if (!t1_values.empty()) { + for (unsigned i = 0; i < state_variables.size(); i++) { + for (unsigned j = 0; j < ncv; j++) { + EXPECT_NEAR(t1_values[i], mechanism_field(test.get(), state_variables[i]).at(j), 1e-6); + } + } + } +} + +TEST(mech_kinetic, kintetic_1_conserve) { + std::vector state_variables = {"s", "h", "d"}; + std::vector t0_values = {0.5, 0.2, 0.3}; + std::vector t1_values = {0.380338, 0.446414, 0.173247}; + + run_test("test0_kin_diff", state_variables, {}, t0_values, t1_values); + run_test("test0_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_kinetic, kintetic_2_conserve) { + std::vector state_variables = {"a", "b", "x", "y"}; + std::vector t0_values = {0.2, 0.8, 0.6, 0.4}; + std::vector t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + + run_test("test1_kin_diff", state_variables, {}, t0_values, t1_values); + run_test("test1_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_linear, linear) { + std::vector state_variables = {"h", "s", "d"}; + std::vector values = {0.5, 0.2, 0.3}; + std::unordered_map assigned_variables = {{"a0", 2.5}, {"a1",0.5}, {"a2",3}, {"a3",2.3}}; + + run_test("test_linear_state", state_variables, assigned_variables, {}, values); + run_test("test_linear_init", state_variables, assigned_variables, values, {}); + run_test("test_linear_init_shuffle", state_variables, assigned_variables, values, {}); +} + +#ifdef ARB_GPU_ENABLED +TEST(mech_kinetic_gpu, kintetic_1_conserve) { + std::vector state_variables = {"s", "h", "d"}; + std::vector t0_values = {0.5, 0.2, 0.3}; + std::vector t1_values = {0.380338, 0.446414, 0.173247}; + + run_test("test0_kin_diff", state_variables, {}, t0_values, t1_values); + run_test("test0_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_kinetic_gpu, kintetic_2_conserve) { + std::vector state_variables = {"a", "b", "x", "y"}; + std::vector t0_values = {0.2, 0.8, 0.6, 0.4}; + std::vector t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + + run_test("test1_kin_diff", state_variables, {}, t0_values, t1_values); + run_test("test1_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_linear_gpu, linear) { + std::vector state_variables = {"h", "s", "d"}; + std::vector values = {0.5, 0.2, 0.3}; + std::unordered_map assigned_variables = {{"a0", 2.5},{"a1",0.5},{"a2",3},{"a3",2.3}}; + + run_test("test_linear_state", state_variables, assigned_variables, {}, values); + run_test("test_linear_init", state_variables, assigned_variables, values, {}); + run_test("test_linear_init_shuffle", state_variables, assigned_variables, values, {}); +} + +#endif diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index ccd007e359..5e9606c2a4 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -9,6 +9,9 @@ #include "unit_test_catalogue.hpp" #include "mechanisms/celsius_test.hpp" #include "mechanisms/test0_kin_diff.hpp" +#include "mechanisms/test_linear_state.hpp" +#include "mechanisms/test_linear_init.hpp" +#include "mechanisms/test_linear_init_shuffle.hpp" #include "mechanisms/test0_kin_conserve.hpp" #include "mechanisms/test1_kin_diff.hpp" #include "mechanisms/test1_kin_conserve.hpp" @@ -40,6 +43,9 @@ mechanism_catalogue make_unit_test_catalogue() { mechanism_catalogue cat; ADD_MECH(cat, celsius_test) + ADD_MECH(cat, test_linear_state) + ADD_MECH(cat, test_linear_init) + ADD_MECH(cat, test_linear_init_shuffle) ADD_MECH(cat, test0_kin_diff) ADD_MECH(cat, test0_kin_conserve) ADD_MECH(cat, test1_kin_diff)