Skip to content

Commit

Permalink
defines '==' for ghost structs to be the same as '==='
Browse files Browse the repository at this point in the history
  • Loading branch information
ArquintL committed Jan 24, 2025
1 parent 5f77eee commit 0e6da08
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 22 deletions.
14 changes: 3 additions & 11 deletions src/main/scala/viper/gobra/ast/internal/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1386,14 +1386,6 @@ case class SliceT(elems : Type, addressability: Addressability) extends PrettyTy
* The (composite) type of maps from type `keys` to type `values`.
*/
case class MapT(keys: Type, values: Type, addressability: Addressability) extends PrettyType(s"map[$keys]$values") {
def hasGhostField(k: Type): Boolean = k match {
case StructT(fields, _) => fields exists (_.ghost)
case _ => false
}
// this check must be done here instead of at the type system level because the concrete AST does not support
// ghost fields yet
require(!hasGhostField(keys))

override def equalsWithoutMod(t: Type): Boolean = t match {
case MapT(otherKeys, otherValues, _) => keys.equalsWithoutMod(otherKeys) && values.equalsWithoutMod(otherValues)
case _ => false
Expand Down Expand Up @@ -1512,14 +1504,14 @@ case class PredT(args: Vector[Type], addressability: Addressability) extends Pre

// StructT does not have a name because equality of two StructT does not depend at all on their declaration site but
// only on their structure, i.e. whether the fields (and addressability) are equal
case class StructT(fields: Vector[Field], addressability: Addressability) extends PrettyType(fields.mkString("struct{", ", ", "}")) with TopType {
case class StructT(fields: Vector[Field], ghost: Boolean, addressability: Addressability) extends PrettyType(fields.mkString(if (ghost) "ghost " else "" + "struct{", ", ", "}")) with TopType {
override def equalsWithoutMod(t: Type): Boolean = t match {
case StructT(otherFields, _) => fields.zip(otherFields).forall{ case (l, r) => l.typ.equalsWithoutMod(r.typ) }
case StructT(otherFields, otherGhost, _) => ghost == otherGhost && fields.zip(otherFields).forall{ case (l, r) => l.name == r.name && l.ghost == r.ghost && l.typ.equalsWithoutMod(r.typ) }
case _ => false
}

override def withAddressability(newAddressability: Addressability): StructT =
StructT(fields.map(f => Field(f.name, f.typ.withAddressability(Addressability.field(newAddressability)), f.ghost)(f.info)), newAddressability)
StructT(fields.map(f => Field(f.name, f.typ.withAddressability(Addressability.field(newAddressability)), f.ghost)(f.info)), ghost = ghost, newAddressability)
}

case class InterfaceT(name: String, addressability: Addressability) extends PrettyType(s"interface{ name is $name }") with TopType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ object CGEdgesTerminationTransform extends InternalTransform {
case in.DefinedT(name, _) => in.DefinedTExpr(name)(src)
case in.PointerT(t, _) => in.PointerTExpr(typeAsExpr(t)(src))(src)
case in.TupleT(ts, _) => in.TupleTExpr(ts map(typeAsExpr(_)(src)))(src)
case in.StructT(fields: Vector[in.Field], _) =>
case in.StructT(fields: Vector[in.Field], _, _) =>
in.StructTExpr(fields.map(field => (field.name, typeAsExpr(field.typ)(src), field.ghost)))(src)
case _ => Violation.violation(s"no corresponding type expression matched: $t")
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/viper/gobra/frontend/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3783,7 +3783,7 @@ object Desugar extends LazyLogging {

case t: Type.StructT =>
val inFields: Vector[in.Field] = structD(t, addrMod)(src)
registerType(in.StructT(inFields, addrMod))
registerType(in.StructT(inFields, ghost = t.isGhost, addrMod))

case t: Type.AdtT =>
val adtName = nm.adt(t.declaredType)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/viper/gobra/frontend/info/base/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object Type {

case class StructEmbeddedT(typ: Type, isGhost: Boolean) extends StructClauseT

case class StructT(clauses: ListMap[String, StructClauseT], decl: PStructType, context: ExternalTypeInfo) extends ContextualType {
case class StructT(clauses: ListMap[String, StructClauseT], isGhost: Boolean, decl: PStructType, context: ExternalTypeInfo) extends ContextualType {
lazy val fieldsAndEmbedded: ListMap[String, Type] = clauses.map(extractTyp)
lazy val fields: ListMap[String, Type] = clauses.filter(isField).map(extractTyp)
lazy val embedded: ListMap[String, Type] = clauses.filterNot(isField).map(extractTyp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ trait TypeIdentity extends BaseProperty { this: TypeInfoImpl =>
case (l: DomainT, r: DomainT) => l == r
case (l: AdtT, r: AdtT) => l == r

case (StructT(clausesL, _, contextL), StructT(clausesR, _, contextR)) =>
contextL == contextR && clausesL.size == clausesR.size && clausesL.zip(clausesR).forall {
case ((lId, lc), (rId, rc)) => lId == rId && identicalTypes(lc.typ, rc.typ) && ((lc, rc) match {
case (StructT(clausesL, isGhostL, _, contextL), StructT(clausesR, isGhostR, _, contextR)) =>
isGhostL == isGhostR && contextL == contextR && clausesL.size == clausesR.size && clausesL.zip(clausesR).forall {
case ((lId, lc), (rId, rc)) => lId == rId && lc.isGhost == rc.isGhost && identicalTypes(lc.typ, rc.typ) && ((lc, rc) match {
case (_: StructFieldT, _: StructFieldT) => true
case (_: StructEmbeddedT, _: StructEmbeddedT) => true
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ trait TypeTyping extends BaseTyping { this: TypeInfoImpl =>

case n@ PMapType(key, elem) => isType(key).out ++ isType(elem).out ++
error(n, s"map key $key is not comparable", !comparableType(typeSymbType(key))) ++
error(n, s"map key $key cannot contain ghost fields", isStructTypeWithGhostFields(typeSymbType(key)))
error(n, s"map key $key can neither be a ghost struct nor contain ghost fields", isStructTypeWithGhostFields(typeSymbType(key)))

case t: PStructType =>
t.embedded.flatMap(e => isNotPointerTypePE.errors(e.typ)(e)) ++
Expand Down Expand Up @@ -191,7 +191,7 @@ trait TypeTyping extends BaseTyping { this: TypeInfoImpl =>
case (prev, x: PEmbeddedDecl) => prev ++ makeEmbedded(x, isGhost = isStructTypeGhost)
case (prev, PExplicitGhostStructClause(x: PEmbeddedDecl)) => prev ++ makeEmbedded(x, isGhost = true)
}
StructT(clauses, t, this)
StructT(clauses, isGhost = isStructTypeGhost, t, this)
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/viper/gobra/translator/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ object Names {
// we use a dollar sign to mark the beginning and end of the type list to avoid that `Tuple(Tuple(X), Y)` and `Tuple(Tuple(X, Y))` map to the same name:
case in.TupleT(ts, addr) => s"Tuple$$${ts.map(serializeType).mkString("")}$$${serializeAddressability(addr)}"
case in.PredT(ts, addr) => s"Pred$$${ts.map(serializeType).mkString("")}$$${serializeAddressability(addr)}"
case in.StructT(fields, addr) => s"Struct${serializeFields(fields)}${serializeAddressability(addr)}"
case in.StructT(fields, _, addr) => s"Struct${serializeFields(fields)}${serializeAddressability(addr)}"
case in.FunctionT(args, res, addr) => s"Func$$${args.map(serializeType).mkString("")}$$${res.map(serializeType).mkString("")}$$${serializeAddressability(addr)}"
case in.InterfaceT(name, addr) => s"Interface$name${serializeAddressability(addr)}"
case in.ChannelT(elemT, addr) => s"Channel${serializeType(elemT)}${serializeAddressability(addr)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,18 @@ class StructEncoding extends TypeEncoding {
*
* [(lhs: Struct{F}) == rhs: Struct{_}] -> AND f in actual(F): [lhs.f == rhs.f] (NOTE: f ranges only over actual fields since `goEqual` corresponds to actual comparison)
*/
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = default(super.goEqual(ctx))(structEqual(ctx, useGoEquality = true))
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = default(super.goEqual(ctx)) {
case (lhs :: ctx.Struct(_), rhs :: ctx.Struct(_), src) =>
def hasGhostStructType(e: in.Expr): Boolean = underlyingType(e.typ)(ctx) match {
case t: in.StructT => t.ghost
case _ => false
}

// if lhs is a ghost struct (_not_ an actual struct contains just ghost fields) then we use ghost equality
val isGhostStructComparison = hasGhostStructType(lhs)
require(isGhostStructComparison == hasGhostStructType(rhs), "type-system should enforce that the struct types involved in a comparison agree on their ghostness")
structEqual(ctx, useGoEquality = !isGhostStructComparison)(lhs, rhs, src)
}

/**
* Encodes equality of two struct values under consideration of either the Go or Gobra/ghost semantics
Expand Down Expand Up @@ -287,7 +298,7 @@ class StructEncoding extends TypeEncoding {
*/
private val shDfltFunc: FunctionGenerator[Vector[in.Field]] = new FunctionGenerator[Vector[in.Field]] {
override def genFunction(fs: Vector[in.Field])(ctx: Context): vpr.Function = {
val resType = in.StructT(fs, Shared)
val resType = in.StructT(fs, ghost = false, Shared) // ghostness does not matter as the resulting Viper type is the same
val vResType = typ(ctx)(resType)
val src = in.DfltVal(resType)(Source.Parser.Internal)
// variable name does not matter because it is turned into a vpr.Result
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

package GhostStructComparison01

// a ghost struct cannot be compared to an actual struct

ghost type GhostStruct struct {
f int
}

type ActualStruct struct {
ghost f int
}

func foo(x ActualStruct, ghost y GhostStruct) {
//:: ExpectedOutput(type_error)
assert x == y
//:: ExpectedOutput(type_error)
assert x === y
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

package GhostStructComparison01

// when comparing ghost structs, ghost equality should be used. I.e., `==` behaves like `===`.

ghost type GhostStruct struct {
f int
}

func test1(ghost x, y GhostStruct) {
//:: ExpectedOutput(assert_error:assertion_error)
assert x == y
}

func test2(ghost x, y GhostStruct) {
//:: ExpectedOutput(assert_error:assertion_error)
assert x === y
}
29 changes: 29 additions & 0 deletions src/test/resources/regressions/features/maps/maps-fail3.gobra
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

package mapsfail3

// tests that neither structs with ghost fields, ghost structs, nor other ghost types can serve as a map's key

type ActualStruct struct {
ghost f int
}

ghost type GhostStruct struct {
f int
}

//:: ExpectedOutput(type_error)
func test1(m map[ActualStruct]int) int {
return 42
}

//:: ExpectedOutput(type_error)
func test2(m map[GhostStruct]int) int {
return 42
}

//:: ExpectedOutput(type_error)
func test3(m map[seq[int]]int) int {
return 42
}

0 comments on commit 0e6da08

Please sign in to comment.