Skip to content

Commit

Permalink
modcc: parse and process LINEAR blocks (#840)
Browse files Browse the repository at this point in the history
Add support for parsing and processing `LINEAR` blocks: 

Changes: 
* `SOLVE` expressions can be called from inside an `INITIAL` block, but only if they are solving a linear system
* Tilde expressions can now be either linear expressions or reaction expressions
* Linear expressions need to be rewritten before being sent to the solver, this is done using `LinearRewriter`
* The linear system is setup in `LinearSolverVisitor` fills the lhs and rhs of the symbolic matrix   
* The matrix is recued using `gj_reduce`, which now works on non-diagonal matrices. 

Fixes #839
  • Loading branch information
noraabiakar authored and bcumming committed Aug 21, 2019
1 parent 1a07b81 commit 336c057
Show file tree
Hide file tree
Showing 26 changed files with 735 additions and 196 deletions.
1 change: 1 addition & 0 deletions modcc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ set(libmodcc_sources
functioninliner.cpp
lexer.cpp
kineticrewriter.cpp
linearrewriter.cpp
module.cpp
parser.cpp
solvers.cpp
Expand Down
22 changes: 22 additions & 0 deletions modcc/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ void StoichExpression::semantic(scope_ptr scp) {
}
}

/*******************************************************************************
LinearExpression
*******************************************************************************/

expression_ptr LinearExpression::clone() const {
return make_expression<LinearExpression>(
location_, lhs()->clone(), rhs()->clone());
}

void LinearExpression::semantic(scope_ptr scp) {
scope_ = scp;
lhs_->semantic(scp);
rhs_->semantic(scp);

if(rhs_->is_procedure_call()) {
error("procedure calls can't be made in an expression");
}
}

/*******************************************************************************
ConserveExpression
*******************************************************************************/
Expand Down Expand Up @@ -984,6 +1003,9 @@ void ConserveExpression::accept(Visitor *v) {
void ReactionExpression::accept(Visitor *v) {
v->visit(this);
}
void LinearExpression::accept(Visitor *v) {
v->visit(this);
}
void StoichExpression::accept(Visitor *v) {
v->visit(this);
}
Expand Down
19 changes: 18 additions & 1 deletion modcc/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BinaryExpression;
class UnaryExpression;
class AssignmentExpression;
class ConserveExpression;
class LinearExpression;
class ReactionExpression;
class StoichExpression;
class StoichTermExpression;
Expand Down Expand Up @@ -79,7 +80,8 @@ enum class procedureKind {
net_receive, ///< NET_RECEIVE
breakpoint, ///< BREAKPOINT
kinetic, ///< KINETIC
derivative ///< DERIVATIVE
derivative, ///< DERIVATIVE
linear, ///< LINEAR
};
std::string to_string(procedureKind k);

Expand Down Expand Up @@ -168,6 +170,7 @@ class Expression {
virtual UnaryExpression* is_unary() {return nullptr;}
virtual AssignmentExpression* is_assignment() {return nullptr;}
virtual ConserveExpression* is_conserve() {return nullptr;}
virtual LinearExpression* is_linear() {return nullptr;}
virtual ReactionExpression* is_reaction() {return nullptr;}
virtual StoichExpression* is_stoich() {return nullptr;}
virtual StoichTermExpression* is_stoich_term() {return nullptr;}
Expand Down Expand Up @@ -1277,6 +1280,20 @@ class ConserveExpression : public BinaryExpression {
void accept(Visitor *v) override;
};

class LinearExpression : public BinaryExpression {
public:
LinearExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs)
: BinaryExpression(loc, tok::eq, std::move(lhs), std::move(rhs))
{}

LinearExpression* is_linear() override {return this;}
expression_ptr clone() const override;

void semantic(scope_ptr scp) override;

void accept(Visitor *v) override;
};

class AddBinaryExpression : public BinaryExpression {
public:
AddBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs)
Expand Down
24 changes: 24 additions & 0 deletions modcc/lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,30 @@ Token Lexer::peek() {
return t;
}

bool Lexer::search_to_eol(tok const& t) {
// save the current position
const char *oldpos = current_;
const char *oldlin = line_;
Location oldloc = location_;

Token p = token_;
bool ret = false;
while (line_ == oldlin && p.type != tok::eof) {
if (p.type == t) {
ret = true;
break;
}
p = parse();
}

// reset position
current_ = oldpos;
location_ = oldloc;
line_ = oldlin;

return ret;
}

// scan floating point number from stream
Token Lexer::number() {
std::string str;
Expand Down
3 changes: 3 additions & 0 deletions modcc/lexer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class Lexer {
// return the next token in the stream without advancing the current position
Token peek();

// Look for `t` until new line or eof without advancing the current position, return true if found
bool search_to_eol(tok const& t);

// scan a number from the stream
Token number();

Expand Down
88 changes: 88 additions & 0 deletions modcc/linearrewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <iostream>
#include <map>
#include <string>
#include <list>

#include "astmanip.hpp"
#include "symdiff.hpp"
#include "visitor.hpp"

class LinearRewriter : public BlockRewriterBase {
public:
using BlockRewriterBase::visit;

LinearRewriter(std::vector<std::string> st_vars): state_vars(st_vars) {}
LinearRewriter(scope_ptr enclosing_scope): BlockRewriterBase(enclosing_scope) {}

virtual void visit(LinearExpression *e) override;

protected:
virtual void reset() override {
BlockRewriterBase::reset();
}

private:
std::vector<std::string> state_vars;
};

expression_ptr linear_rewrite(BlockExpression* block, std::vector<std::string> state_vars) {
LinearRewriter visitor(state_vars);
block->accept(&visitor);
return visitor.as_block(false);
}

// LinearRewriter implementation follows.

// Factorize the linear expression in terms of the state variables and place
// the resulting sum of products on the lhs. Place everything else on the rhs
void LinearRewriter::visit(LinearExpression* e) {
Location loc = e->location();
scope_ptr scope = e->scope();

expression_ptr lhs;
for (const auto& state : state_vars) {
// To factorize w.r.t state, differentiate the lhs and rhs
auto ident = make_expression<IdentifierExpression>(loc, state);
auto coeff = constant_simplify(make_expression<SubBinaryExpression>(loc,
symbolic_pdiff(e->lhs(), state),
symbolic_pdiff(e->rhs(), state)));

if (expr_value(coeff) != 0) {
auto local_coeff = make_unique_local_assign(scope, coeff, "l_");
statements_.push_back(std::move(local_coeff.local_decl));
statements_.push_back(std::move(local_coeff.assignment));

auto pair = make_expression<MulBinaryExpression>(loc, std::move(local_coeff.id), std::move(ident));

// Construct the lhs of the new linear expression
if (!lhs) {
lhs = std::move(pair);
} else {
lhs = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(pair));
}
}
}

// To find the rhs of the new linear expression, simplify the old
// linear expression with state variables set to zero
auto rhs_0 = e->lhs()->clone();
auto rhs_1 = e->rhs()->clone();

for (auto state: state_vars) {
auto zero_expr = make_expression<NumberExpression>(loc, 0.0);
rhs_0 = substitute(rhs_0, state, zero_expr);
rhs_1 = substitute(rhs_1, state, zero_expr);
}
rhs_0 = constant_simplify(rhs_0);
rhs_1 = constant_simplify(rhs_1);

auto rhs = constant_simplify(make_expression<SubBinaryExpression>(loc, std::move(rhs_1), std::move(rhs_0)));

auto local_rhs = make_unique_local_assign(scope, rhs, "l_");
statements_.push_back(std::move(local_rhs.local_decl));
statements_.push_back(std::move(local_rhs.assignment));

rhs = std::move(local_rhs.id);

statements_.push_back(make_expression<LinearExpression>(loc, std::move(lhs), std::move(rhs)));
}
6 changes: 6 additions & 0 deletions modcc/linearrewriter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#include "expression.hpp"

// Translate a supplied LINEAR block.
expression_ptr linear_rewrite(BlockExpression*, std::vector<std::string>);
45 changes: 44 additions & 1 deletion modcc/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "functionexpander.hpp"
#include "functioninliner.hpp"
#include "kineticrewriter.hpp"
#include "linearrewriter.hpp"
#include "module.hpp"
#include "parser.hpp"
#include "solvers.hpp"
Expand Down Expand Up @@ -279,7 +280,45 @@ bool Module::semantic() {
auto& init_body = api_init->body()->statements();

for(auto& e : *proc_init->body()) {
init_body.emplace_back(e->clone());
auto solve_expression = e->is_solve_statement();
if (solve_expression) {
// Grab SOLVE statements, put them in `body` after translation.
std::set<std::string> solved_ids;
std::unique_ptr<SolverVisitorBase> solver = std::make_unique<SparseSolverVisitor>();

// The solve expression inside an initial block can only refer to a linear block
auto solve_proc = solve_expression->procedure();

if (solve_proc->kind() == procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
linear_rewrite(solve_proc->body(), state_vars)->accept(solver.get());
} else {
error("A SOLVE expression in an INITIAL block can only be used to solve a LINEAR block, which" +
solve_expression->name() + "is not.", solve_expression->location());
return false;
}

if (auto solve_block = solver->as_block(false)) {
// Check that we didn't solve an already solved variable.
for (const auto &id: solver->solved_identifiers()) {
if (solved_ids.count(id) > 0) {
error("Variable " + id + " solved twice!", solve_expression->location());
return false;
}
solved_ids.insert(id);
}
// Copy body into nrn_init.
for (auto &stmt: solve_block->is_block()->statements()) {
init_body.emplace_back(stmt->clone());
}
} else {
// Something went wrong: copy errors across.
append_errors(solver->errors());
return false;
}
} else {
init_body.emplace_back(e->clone());
}
}

api_init->semantic(symbols_);
Expand Down Expand Up @@ -337,6 +376,10 @@ bool Module::semantic() {
if (deriv->kind()==procedureKind::kinetic) {
kinetic_rewrite(deriv->body())->accept(solver.get());
}
else if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
linear_rewrite(deriv->body(), state_vars)->accept(solver.get());
}
else {
deriv->body()->accept(solver.get());
for (auto& s: deriv->body()->statements()) {
Expand Down
6 changes: 6 additions & 0 deletions modcc/msparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ class matrix {
bool empty() const { return size()==0; }
bool augmented() const { return aug!=npos; }

void clear() {
rows.clear();
cols = 0;
aug = row_npos;
}

// Add a column on the right as part of the augmented submatrix.
// The new entries are provided by a (full, dense representation)
// sequence of values.
Expand Down
Loading

0 comments on commit 336c057

Please sign in to comment.