Skip to content

Commit

Permalink
Worked on evals
Browse files Browse the repository at this point in the history
  • Loading branch information
rayman2000 committed Jan 20, 2025
1 parent 97353cd commit c71de18
Show file tree
Hide file tree
Showing 43 changed files with 1,179 additions and 1,011 deletions.
75 changes: 50 additions & 25 deletions src/main/scala/biabduction/Abduction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package viper.silicon.biabduction

import viper.silicon
import viper.silicon.interfaces._
import viper.silicon.resources.{FieldID, PredicateID}
import viper.silicon.rules._
import viper.silicon.rules.chunkSupporter.findChunk
import viper.silicon.rules.evaluator.{eval, evals}
Expand All @@ -11,8 +10,8 @@ import viper.silicon.state._
import viper.silicon.state.terms.{SortWrapper, Term}
import viper.silicon.utils.freshSnap
import viper.silicon.verifier.Verifier
import viper.silver.ast
import viper.silver.ast._
import viper.silver.verifier.errors.Internal
import viper.silver.verifier.{BiAbductionQuestion, DummyReason}

object AbductionApplier extends RuleApplier[AbductionQuestion] {
Expand All @@ -21,8 +20,8 @@ object AbductionApplier extends RuleApplier[AbductionQuestion] {
}

case class AbductionQuestion(s: State, v: Verifier, goal: Seq[Exp],
lostAccesses: Map[Exp, Term] = Map(), foundState: Seq[Exp] = Seq(),
foundStmts: Seq[Stmt] = Seq(), trigger: Option[Positioned]) extends BiAbductionQuestion
lostAccesses: Map[Exp, Term] = Map(), foundState: Seq[(Exp, Option[BasicChunk])] = Seq(),
foundStmts: Seq[Stmt] = Seq(), trigger: Option[Positioned], stateAllowed: Boolean) extends BiAbductionQuestion

/**
* A rule for abduction. A rule is a pair of a check method and an apply method. The check method checks whether the
Expand Down Expand Up @@ -222,7 +221,7 @@ object AbductionFold extends AbductionRule {
//val pveTransformed = pve.withReasonNodeTransformed(reasonTransformer)

val fold = Fold(a)()
executor.exec(q.s, fold, q.v, doAbduction = true, q.trigger) { (s1, v1) =>
executor.exec(q.s, fold, q.v, q.trigger, q.stateAllowed) { (s1, v1) =>
val lost = q.lostAccesses + (field -> SortWrapper(chunk.snap, chunk.snap.sort))
val q2 = q.copy(s = s1, v = v1, foundStmts = q.foundStmts :+ fold, lostAccesses = lost, goal = g1)
Q(Some(q2))
Expand Down Expand Up @@ -324,7 +323,7 @@ object AbductionUnfold extends AbductionRule {
produces(q.s, freshSnap, conds, _ => pve, q.v)((s1, v1) => {
val wildcards = q.s.constrainableARPs -- q.s.constrainableARPs
predicateSupporter.unfold(s1, pred.loc(q.s.program), predChunk.args.toList, None, terms.FullPerm, None, wildcards, pve, v1, pred) { (s2, v2) =>
Q(Some(q.copy(s = s2, v = v2, foundStmts = q.foundStmts :+ Unfold(PredicateAccessPredicate(pred, FullPerm()())())(), foundState = q.foundState ++ conds)))
Q(Some(q.copy(s = s2, v = v2, foundStmts = q.foundStmts :+ Unfold(PredicateAccessPredicate(pred, FullPerm()())())(), foundState = q.foundState ++ conds.map(c => c -> None))))
}
})
}
Expand Down Expand Up @@ -374,7 +373,7 @@ object AbductionApply extends AbductionRule {
val stmts = q.foundStmts ++ lhsRes.foundStmts :+ Apply(wand)()
val state = q.foundState ++ lhsRes.foundState
val lost = q.lostAccesses ++ lhsRes.lostAccesses
Q(Some(AbductionQuestion(s2, v2, g1, lost, state, stmts, q.trigger)))
Q(Some(AbductionQuestion(s2, v2, g1, lost, state, stmts, q.trigger, q.stateAllowed)))
/*val g1 = q.goal.filterNot(_ == wand.right) :+ wand.left
consumer.consume(s1, wand, pve, v1)((s2, _, v2) =>
Q(Some(q.copy(s = s2, v = v2, goal = g1, foundStmts = q.foundStmts :+ Apply(wand)())))
Expand Down Expand Up @@ -402,23 +401,42 @@ object AbductionPackage extends AbductionRule {
q.goal.collectFirst { case a: MagicWand => a } match {
case None => Q(None)
case Some(wand) =>
producer.produce(q.s, freshSnap, wand.left, pve, q.v)((s1, v1) => {

val packQ = q.copy(s = s1, v = v1, goal = Seq(wand.right))
AbductionApplier.applyRules(packQ) { packRes =>

if (packRes.goal.nonEmpty) {
Q(None)
} else {

val g1 = q.goal.filterNot(_ == wand)
val stmts = q.foundStmts :+ Package(wand, Seqn(packRes.foundStmts.reverse, Seq())())()
val pres = q.foundState ++ packRes.foundState
//val lost = q.lostAccesses ++ packRes.lostAccesses
Q(Some(q.copy(s = packRes.s, v = packRes.v, goal = g1, foundStmts = stmts, foundState = pres)))
executionFlowController.locally(q.s, q.v) { (s0, v0) =>

// TODO This may produce things that are already in the state
producer.produce(s0, freshSnap, wand.left, pve, v0) { (s1, v1) =>
val packQ = q.copy(s = s1, v = v1, goal = Seq(wand.right))
AbductionApplier.applyRules(packQ) { packRes =>
if (packRes.goal.nonEmpty) {
Failure(pve dueTo(DummyReason))
//T(BiAbductionFailure(packRes.s, packRes.v, packRes.v.decider.pcs.duplicate()))
} else {
val newState = packRes.foundState
val newStmts = packRes.foundStmts
Success(Some(AbductionSuccess(packRes.s, packRes.v, packRes.v.decider.pcs.duplicate(), newState, newStmts, Map(), None)))
}
}
}
})
} match {
case _: FatalResult => Q(None)
case suc: NonFatalResult =>

val abdRes = abductionUtils.getAbductionSuccesses(suc)
val stmts = abdRes.flatMap(_.stmts) //.reverse?
val state = abdRes.flatMap(_.state).reverse

produces(q.s, freshSnap, state.map(_._1), _ => pve, q.v){ (s1, v1) =>
val script = Seqn(stmts, Seq())()
magicWandSupporter.packageWand(s1, wand, script, pve, v1) {
(s2, wandChunk, v2) =>
val g1 = q.goal.filterNot(_ == wand)
val finalStmts = q.foundStmts :+ Package(wand, script)()
val finalState = q.foundState ++ state
//val lost = q.lostAccesses ++ packRes.lostAccesses
Q(Some(q.copy(s = s2.copy(h = s2.reserveHeaps.head.+(wandChunk)), v = v2, goal = g1, foundStmts = finalStmts, foundState = finalState)))
}
}
}
}
}
}
Expand All @@ -430,13 +448,20 @@ object AbductionPackage extends AbductionRule {
object AbductionMissing extends AbductionRule {

override def apply(q: AbductionQuestion)(Q: Option[AbductionQuestion] => VerificationResult): VerificationResult = {
val accs = q.goal.collect { case e: AccessPredicate => e }
if (accs.isEmpty) {
val accs = q.goal.collect {
case e: FieldAccessPredicate => e
case e: PredicateAccessPredicate => e
}
if (!q.stateAllowed || accs.isEmpty) {
Q(None)
} else {
val g1 = q.goal.filterNot(accs.contains)
producer.produces(q.s, freshSnap, accs, _ => pve, q.v) { (s1, v1) =>
Q(Some(q.copy(s = s1, v = v1, goal = g1, foundState = q.foundState ++ accs)))
val locs: Map[LocationAccess, Exp] = accs.map {p => p.loc -> p}.toMap
abductionUtils.findChunks(locs.keys.toSeq, s1, v1, Internal()) { newChunks =>
val newState = newChunks.map {case (c, loc) => (locs(loc), Some(c))}
Q(Some(q.copy(s = s1, v = v1, goal = g1, foundState = q.foundState ++ newState)))
}
}
}
}
Expand Down
110 changes: 67 additions & 43 deletions src/main/scala/biabduction/Abstraction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,98 +14,122 @@ object AbstractionApplier extends RuleApplier[AbstractionQuestion] {
override val rules: Seq[AbstractionRule] = Seq(AbstractionFold, AbstractionPackage, AbstractionJoin, AbstractionApply)
}

case class AbstractionQuestion(s: State, v: Verifier, fixedChunks: Seq[Chunk]) {

// TODO we assume each field only appears in at most one predicate
def fields: Map[Field, Predicate] = s.program.predicates.flatMap { pred => pred.body.get.collect { case fa: FieldAccessPredicate => (fa.loc.field, pred) } }.toMap
case class AbstractionQuestion(s: State, v: Verifier) {

def varTran: VarTransformer = VarTransformer(s, v, s.g.values, s.h)

def isTriggerField(bc: BasicChunk): Boolean = {
bc.resourceID match {
case FieldID => fields.contains(abductionUtils.getField(bc.id, s.program)) && !fixedChunks.contains(bc)
case _ => false
}
}
}

trait AbstractionRule extends BiAbductionRule[AbstractionQuestion]

object AbstractionFold extends AbstractionRule {

private def checkChunks(chunks: Seq[BasicChunk], q: AbstractionQuestion)(Q: Option[AbstractionQuestion] => VerificationResult): VerificationResult = {

// TODO we assume each field only appears in at most one predicate
private def getFieldPredicate(bc: BasicChunk, q: AbstractionQuestion): Option[Predicate] = {

if (bc.resourceID != FieldID) None else {
val field = abductionUtils.getField(bc.id, q.s.program)

q.s.program.predicates.collectFirst { case pred if pred.collect{case fa: FieldAccess => fa.field}.toSeq.contains(field) => pred }
}
}

private def checkChunks(chunks: Seq[(BasicChunk, Predicate)], q: AbstractionQuestion)(Q: Option[AbstractionQuestion] => VerificationResult): VerificationResult = {
chunks match {
case Seq() => Q(None)
case chunk +: rest =>
val pred = q.fields(abductionUtils.getField(chunk.id, q.s.program))
case _ if chunks.isEmpty => Q(None)
case (chunk, pred) +: rest =>
//val pred = q.fields(abductionUtils.getField(chunk.id, q.s.program))
val wildcards = q.s.constrainableARPs -- q.s.constrainableARPs
executionFlowController.tryOrElse0(q.s, q.v) {
(s1, v1, T) =>

val fargs = pred.formalArgs.map(_.localVar)
//val fargs = pred.formalArgs.map(_.localVar)
val eArgs = q.varTran.transformTerm(chunk.args.head)
val formalsToActuals: Map[LocalVar, Exp] = fargs.zip(eArgs).to(Map)
val reasonTransformer = (n: viper.silver.verifier.errors.ErrorNode) => n.replace(formalsToActuals)
val pveTransformed = pve.withReasonNodeTransformed(reasonTransformer)
//val formalsToActuals: Map[LocalVar, Exp] = fargs.zip(eArgs).to(Map)
//val reasonTransformer = (n: viper.silver.verifier.errors.ErrorNode) => n.replace(formalsToActuals)
//val pveTransformed = pve.withReasonNodeTransformed(reasonTransformer)

val fold = Fold(PredicateAccessPredicate(PredicateAccess(Seq(eArgs.get), pred.name)(), FullPerm()())())()
executor.exec(s1, fold, v1, None, abdStateAllowed = false)(T)

// TODO nklose this can branch
predicateSupporter.fold(s1, pred, List(chunk.args.head), None, terms.FullPerm, Some(FullPerm()()), wildcards, pveTransformed, v1)(T)
//predicateSupporter.fold(s1, pred, List(chunk.args.head), None, terms.FullPerm, Some(FullPerm()()), wildcards, pveTransformed, v1)(T)
} {
(s2, v2) => Q(Some(q.copy(s = s2, v = v2)))
} {
f =>

f => checkChunks(rest, q)(Q)

/*
executionFlowController.tryOrElse0(q.s, q.v) {
(s3, v3, T) =>
BiAbductionSolver.solveAbduction(s3, v3, f, None) { (s4, res, v4) =>
res.state match {
case Seq() => T(s4, v4)
case _ => f
}
}
BiAbductionSolver.solveAbductionForError(s3, v3, f, stateAllowed = false, None) { T }
} {
(s5, v5) =>
Q(Some(q.copy(s = s5, v = v5)))
} {
f =>
checkChunks(rest, q)(Q)
}
}*/
}
}
}

override def apply(q: AbstractionQuestion)(Q: Option[AbstractionQuestion] => VerificationResult): VerificationResult = {
val candChunks = q.s.h.values.collect { case bc: BasicChunk if q.isTriggerField(bc) => bc }.toSeq
val candChunks = q.s.h.values.collect { case bc: BasicChunk => (bc, getFieldPredicate(bc, q)) }.collect { case (c, Some(pred)) => (c, pred) }.toSeq
checkChunks(candChunks, q)(Q)
}
}


object AbstractionPackage extends AbstractionRule {

// TODO nklose we should only trigger on fields for which there is a recursive predicate call
// TODO if the fold fails for a different reason than the recursive predicate missing, then this will do nonsense
// We should actually check what is missing and be smarter about what we package.
private def checkField(bc: BasicChunk, q: AbstractionQuestion): Option[MagicWand] = {
if (bc.resourceID != FieldID) None else {

// TODO this fails if the sorts don't line up
q.s.g.termValues.collectFirst{ case (lv, term) if term.sort == bc.snap.sort && q.v.decider.check(terms.Equals(term, bc.snap), Verifier.config.checkTimeout()) => lv} match {
case None => None
case Some(lhsArgExp) =>
val field = abductionUtils.getField(bc.id, q.s.program)
// TODO we assume each field only appears in at most one predicate
val predOpt = q.s.program.predicates.collectFirst { case pred if pred.collect{case fa: FieldAccess => fa.field}.toSeq.contains(field) => pred }
predOpt.flatMap { pred =>
val recPredOpt = pred.collectFirst {
case recPred@PredicateAccess(Seq(FieldAccess(_, field2)), _) if field == field2 => recPred
}
recPredOpt.flatMap { recPred =>
val lhs = PredicateAccessPredicate(PredicateAccess(Seq(lhsArgExp), recPred.predicateName)(NoPosition, NoInfo, NoTrafos), FullPerm()())()
val rhsArg = q.varTran.transformTerm(bc.args.head).get
val rhs = PredicateAccessPredicate(PredicateAccess(Seq(rhsArg), pred)(NoPosition, NoInfo, NoTrafos), FullPerm()())()
Some(MagicWand(lhs, rhs)())
}
}
}
}
}

@tailrec
private def findWandFieldChunk(chunks: Seq[Chunk], q: AbstractionQuestion): Option[(Exp, BasicChunk)] = {
private def findWand(chunks: Seq[Chunk], q: AbstractionQuestion): Option[MagicWand] = {
chunks match {
case Seq() => None
case (chunk: BasicChunk) +: rest if q.isTriggerField(chunk) =>
q.varTran.transformTerm(chunk.snap) match {
case None => findWandFieldChunk(rest, q)
case Some(snap) => Some((snap, chunk))
case (chunk: BasicChunk) +: rest =>
checkField(chunk, q) match {
case None => findWand(rest, q)
case wand => wand
}
case _ +: rest => findWandFieldChunk(rest, q)
case (_: Chunk) +: rest => findWand(rest, q)
}
}

override def apply(q: AbstractionQuestion)(Q: Option[AbstractionQuestion] => VerificationResult): VerificationResult = {

findWandFieldChunk(q.s.h.values.toSeq, q) match {
findWand(q.s.h.values.toSeq, q) match {
case None => Q(None)
case Some((lhsArg, chunk)) =>
val pred = q.fields(abductionUtils.getField(chunk.id, q.s.program))
val lhs = PredicateAccessPredicate(PredicateAccess(Seq(lhsArg), pred)(NoPosition, NoInfo, NoTrafos), FullPerm()())()
val rhsArg = q.varTran.transformTerm(chunk.args.head).get
val rhs = PredicateAccessPredicate(PredicateAccess(Seq(rhsArg), pred)(NoPosition, NoInfo, NoTrafos), FullPerm()())()
val wand = MagicWand(lhs, rhs)()
case Some(wand) =>
executor.exec(q.s, Assert(wand)(), q.v) {
(s1, v1) =>
Q(Some(q.copy(s = s1, v = v1)))
Expand Down Expand Up @@ -137,7 +161,7 @@ object AbstractionApply extends AbstractionRule {

override def apply(q: AbstractionQuestion)(Q: Option[AbstractionQuestion] => VerificationResult): VerificationResult = {
val wands = q.s.h.values.collect { case wand: MagicWandChunk => q.varTran.transformChunk(wand) }.collect { case Some(wand: MagicWand) => wand }
val targets = q.s.h.values.collect { case c: BasicChunk if !q.fixedChunks.contains(c) => q.varTran.transformChunk(c) }.collect { case Some(exp) => exp }.toSeq
val targets = q.s.h.values.collect { case c: BasicChunk => q.varTran.transformChunk(c) }.collect { case Some(exp) => exp }.toSeq

wands.collectFirst { case wand if targets.contains(wand.left) => wand } match {
case None => Q(None)
Expand Down
Loading

0 comments on commit c71de18

Please sign in to comment.