Skip to content

Commit

Permalink
Merge pull request #28 from KacperFKorban/generics
Browse files Browse the repository at this point in the history
Generics
  • Loading branch information
KacperFKorban authored Mar 20, 2024
2 parents 8e1346e + f41e91d commit 742883d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 23 deletions.
85 changes: 64 additions & 21 deletions guinep/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private[guinep] object macros {
case _ => None
}

def wrongParamsListError(f: Expr[Any]): Nothing =
private def wrongParamsListError(f: Expr[Any]): Nothing =
report.errorAndAbort(s"Wrong params list, expected a function reference, got: ${f.show}", f.asTerm.pos)

private def unsupportedFunctionParamType(t: TypeRepr, pos: Option[Position] = None): Nothing = pos match {
Expand All @@ -35,6 +35,16 @@ private[guinep] object macros {
private def select(s: String): Term =
t.select(t.tpe.typeSymbol.methodMember(s).head)

extension (s: Symbol)
private def prettyName: String =
s.name.stripSuffix("$")

extension (tpe: TypeRepr)
private def stripAnnots: TypeRepr = tpe match {
case AnnotatedType(tpe, _) => tpe.stripAnnots
case _ => tpe
}

private def functionNameImpl(f: Expr[Any]): Expr[String] = {
val name = f.asTerm match {
case Inlined(_, _, Lambda(_, body)) =>
Expand Down Expand Up @@ -83,26 +93,28 @@ private[guinep] object macros {
val isEnumCaseNonClassDef = typeSymbol.flags.is(Flags.Enum) && typeSymbol.flags.is(Flags.Case) && !typeSymbol.isClassDef
isModule || isEnumCaseNonClassDef

private def tpeArguments(tpe: TypeRepr): List[TypeRepr] = tpe match {
case AppliedType(tpe, args) => args
case _ => Nil
}

private def functionFormElementFromTree(paramName: String, paramType: TypeRepr): FormElement = paramType match {
case ntpe: NamedType if ntpe.name == "String" => FormElement.TextInput(paramName)
case ntpe: NamedType if ntpe.name == "Int" => FormElement.NumberInput(paramName)
case ntpe: NamedType if ntpe.name == "Boolean" => FormElement.CheckboxInput(paramName)
case ntpe if isProductTpe(ntpe) =>
val classSymbol = ntpe.typeSymbol
val typeDefParams = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam)
val fields = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isValDef).map(_.tree).collect { case v: ValDef => v }
FormElement.FieldSet(paramName, fields.map(v => functionFormElementFromTree(v.name, v.tpt.tpe)))
FormElement.FieldSet(
paramName,
fields.map { valdef =>
functionFormElementFromTree(
valdef.name,
valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs).stripAnnots
)
}
)
case ntpe if isSumTpe(ntpe) =>
val classSymbol = ntpe.typeSymbol
val typeParamSyms = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isType)
val tpeArgs = tpeArguments(ntpe)
val childrenAppliedTpes = classSymbol.children.map(_.typeRef)
val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots)
val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTree("value", t))
val options = classSymbol.children.map(_.name).zip(childrenFormElements)
val options = classSymbol.children.map(_.prettyName).zip(childrenFormElements)
FormElement.Dropdown(paramName, options)
case _ =>
unsupportedFunctionParamType(paramType)
Expand All @@ -113,37 +125,68 @@ private[guinep] object macros {
functionParams(f).map { case ValDef(name, tpt, _) => functionFormElementFromTree(name, tpt.tpe) } .map(Expr(_))
)

private def appliedChild(childSym: Symbol, parentSym: Symbol, parentArgs: List[TypeRepr]): TypeRepr = childSym.tree match {
case classDef @ ClassDef(_, _, parents, _, _) =>
parents
.collect {
case tpt: TypeTree => tpt.tpe
}
.collectFirst {
case AppliedType(tpe, args) if tpe.typeSymbol == parentSym => args
case tpe if tpe.typeSymbol == parentSym => Nil
}.match
case None =>
report.errorAndAbort(s"""PANIC: Could not find applied parent for ${childSym.name}, parents: ${parents.map(_.show).mkString(",")}""", classDef.pos)
case Some(parentExtendsArgs) =>
val childDefArgs = classDef.symbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam).map(_.typeRef)
val childArgTpes = childDefArgs.map { arg =>
arg.substituteTypes(parentExtendsArgs.map(_.typeSymbol), parentArgs)
}
// TODO(kπ) might want to handle the case when there are unsubstituted type parameters left
val childTpe = childSym.typeRef.appliedTo(childArgTpes)
childTpe
case _ =>
childSym.typeRef
}

private def constructArg(paramTpe: TypeRepr, param: Term): Term = {
paramTpe match {
case ntpe: NamedType if ntpe.name == "String" => param.select("asInstanceOf").appliedToType(ntpe)
case ntpe: NamedType if ntpe.name == "Int" => param.select("asInstanceOf").appliedToType(ntpe)
case ntpe: NamedType if ntpe.name == "Boolean" => param.select("asInstanceOf").appliedToType(ntpe)
case ntpe if isCaseObjectTpe(ntpe) && ntpe.typeSymbol.flags.is(Flags.Module) =>
Ref(ntpe.typeSymbol.companionModule)
case ntpe if isCaseObjectTpe(ntpe) =>
Ident(ntpe.typeSymbol.termRef)
Ref(ntpe.typeSymbol)
case ntpe if isProductTpe(ntpe) =>
val classSymbol = ntpe.classSymbol.getOrElse(unsupportedFunctionParamType(paramTpe, Some(param.pos)))
val classSymbol = ntpe.typeSymbol
val typeDefParams = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam)
val fields = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isValDef).map(_.tree)
val paramValue = '{ ${param.asExpr}.asInstanceOf[Map[String, Any]] }.asTerm
val args = fields.collect { case field: ValDef =>
val fieldName = field.asInstanceOf[ValDef].name
val fieldName = field.name
val fieldValue = paramValue.select("apply").appliedTo(Literal(StringConstant(fieldName)))
constructArg(field.tpt.tpe, fieldValue)
constructArg(
field.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs),
fieldValue
)
}
New(Inferred(ntpe)).select(classSymbol.primaryConstructor).appliedToArgs(args)
New(Inferred(ntpe.typeSymbol.typeRef)).select(classSymbol.primaryConstructor).appliedToTypes(ntpe.typeArgs).appliedToArgs(args)
case ntpe if isSumTpe(ntpe) =>
val classSymbol = ntpe.classSymbol.getOrElse(unsupportedFunctionParamType(paramTpe, Some(param.pos)))
val classSymbol = ntpe.typeSymbol
val className = classSymbol.name
val children = classSymbol.children
val childrenAppliedTpes = children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots)
val paramMap = '{ ${param.asExpr}.asInstanceOf[Map[String, Any]] }.asTerm
val paramName = paramMap.select("apply").appliedTo(Literal(StringConstant("name")))
val paramValue = paramMap.select("apply").appliedTo(Literal(StringConstant("value")))
children.foldRight[Term]{
children.zip(childrenAppliedTpes).foldRight[Term]{
'{ throw new RuntimeException(s"Class ${${paramName.asExpr}} is not a child of ${${Expr(className)}}") }.asTerm
} { (child, acc) =>
val childName = Literal(StringConstant(child.name))
} { case ((child, childAppliedTpe), acc) =>
val childName = Literal(StringConstant(child.prettyName))
If(
paramName.select("equals").appliedTo(childName),
constructArg(child.typeRef, paramValue),
constructArg(childAppliedTpe, paramValue),
acc
)
}
Expand Down
20 changes: 18 additions & 2 deletions testcases/src/main/scala/main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ def roll20: Int =
def roll6(): Int =
scala.util.Random.nextInt(6) + 1

sealed trait WeirdGADT[+A]
case class IntValue(value: Int) extends WeirdGADT[Int]
case class SomeValue[+A](value: A) extends WeirdGADT[A]
case class SomeOtherValue[+A, +B](value: A, value2: B) extends WeirdGADT[A]

// This fails on unknown type params
def printsWeirdGADT(g: WeirdGADT[String]): String = g match
case SomeValue(value) => s"SomeValue($value)"
case SomeOtherValue(value, value2) => s"SomeOtherValue($value, $value2)"

// This loops forever
def concatAll(elems: List[String]): String =
elems.mkString

@main
def run: Unit =
guinep.web(
Expand All @@ -66,10 +80,12 @@ def run: Unit =
concat,
giveALongText,
addObj,
// greetMaybeName,
greetMaybeName,
greetInLanguage,
nameWithPossiblePrefix,
nameWithPossiblePrefix1,
roll20,
roll6()
roll6(),
// printsWeirdGADT
// concatAll
)

0 comments on commit 742883d

Please sign in to comment.