From 21115944806a65ddf74ce7ee2e4889a0624855a1 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 27 Jan 2024 07:01:41 +1100 Subject: [PATCH] Track whether or not let expressions failed to solve in solver (#7982) * Track whether or not let expressions failed to solve in solver After mutating an expression, the solver needs to know two things: 1) Did the expression contain the variable we're solving for 2) Was the expression successfully "solved" for the variable. I.e. the variable only appears once in the leftmost position. We need to know this to know property 1 of any subexpressions (i.e. does the right child of the expression contain the variable). This drives what transformations we do in ways that are guaranteed to terminate and not take exponential time. We were tracking property 1 through lets but not property 2, and this meant we were doing unhelpful transformations in some cases. I found a case in the wild where this made a pipeline take > 1 hour to compile (I killed it after an hour). It may have been in an infinite transformation loop, or it might have just been exponential. Not sure. * Remove surplus comma * Fix use of uninitialized value that could cause bad transformation --- src/ModulusRemainder.h | 6 ++++-- src/Solve.cpp | 35 ++++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/ModulusRemainder.h b/src/ModulusRemainder.h index c0341b75abf6..cbcdce10b98c 100644 --- a/src/ModulusRemainder.h +++ b/src/ModulusRemainder.h @@ -7,6 +7,8 @@ #include +#include "Util.h" + namespace Halide { struct Expr; @@ -83,8 +85,8 @@ ModulusRemainder modulus_remainder(const Expr &e, const Scope /** Reduce an expression modulo some integer. Returns true and assigns * to remainder if an answer could be found. */ ///@{ -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); -bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder); +HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope &scope); ///@} void modulus_remainder_test(); diff --git a/src/Solve.cpp b/src/Solve.cpp index a08eedadbd27..22bd14e44412 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -44,18 +44,22 @@ class SolveExpression : public IRMutator { map::iterator iter = cache.find(e); if (iter == cache.end()) { // Not in the cache, call the base class version. - debug(4) << "Mutating " << e << " (" << uses_var << ")\n"; + debug(4) << "Mutating " << e << " (" << uses_var << ", " << failed << ")\n"; bool old_uses_var = uses_var; uses_var = false; + bool old_failed = failed; + failed = false; Expr new_e = IRMutator::mutate(e); - CacheEntry entry = {new_e, uses_var}; + CacheEntry entry = {new_e, uses_var, failed}; uses_var = old_uses_var || uses_var; + failed = old_failed || failed; cache[e] = entry; - debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ")\n"; + debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ", " << failed << ")\n"; return new_e; } else { // Cache hit. uses_var = uses_var || iter->second.uses_var; + failed = failed || iter->second.failed; debug(4) << "(Hit) Rewrote " << e << " -> " << iter->second.expr << " (" << uses_var << ")\n"; return iter->second.expr; } @@ -75,7 +79,7 @@ class SolveExpression : public IRMutator { // stateless, so we can cache everything. struct CacheEntry { Expr expr; - bool uses_var; + bool uses_var, failed; }; map cache; @@ -388,16 +392,25 @@ class SolveExpression : public IRMutator { const Mul *mul_a = a.as(); Expr expr; if (a_uses_var && !b_uses_var) { + const int64_t *ib = as_const_int(b); + auto is_multiple_of_b = [&](const Expr &e) { + if (ib) { + int64_t r = 0; + return reduce_expr_modulo(e, *ib, &r) && r == 0; + } else { + return can_prove(e / b * b == e); + } + }; if (add_a && !a_failed && - can_prove(add_a->a / b * b == add_a->a)) { + is_multiple_of_b(add_a->a)) { // (f(x) + a) / b -> f(x) / b + a / b expr = mutate(simplify(add_a->a / b) + add_a->b / b); } else if (sub_a && !a_failed && - can_prove(sub_a->a / b * b == sub_a->a)) { + is_multiple_of_b(sub_a->a)) { // (f(x) - a) / b -> f(x) / b - a / b expr = mutate(simplify(sub_a->a / b) - sub_a->b / b); } else if (mul_a && !a_failed && no_overflow_int(op->type) && - can_prove(mul_a->b / b * b == mul_a->b)) { + is_multiple_of_b(mul_a->b)) { // (f(x) * a) / b -> f(x) * (a / b) expr = mutate(mul_a->a * (mul_a->b / b)); } @@ -776,6 +789,7 @@ class SolveExpression : public IRMutator { } else if (scope.contains(op->name)) { CacheEntry e = scope.get(op->name); uses_var = uses_var || e.uses_var; + failed = failed || e.failed; return e.expr; } else if (external_scope.contains(op->name)) { Expr e = external_scope.get(op->name); @@ -790,11 +804,14 @@ class SolveExpression : public IRMutator { Expr visit(const Let *op) override { bool old_uses_var = uses_var; + bool old_failed = failed; uses_var = false; + failed = false; Expr value = mutate(op->value); - CacheEntry e = {value, uses_var}; - + CacheEntry e = {value, uses_var, failed}; uses_var = old_uses_var; + failed = old_failed; + ScopedBinding bind(scope, op->name, e); return mutate(op->body); }