Skip to content

Commit

Permalink
Bytecode: Avoid casts to supertype (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstepanov authored Dec 5, 2024
1 parent a784414 commit 355e935
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import io.micronaut.inject.ast.ClassElement;
import io.micronaut.sourcegen.bytecode.MethodContext;
import io.micronaut.sourcegen.bytecode.TypeUtils;
import io.micronaut.sourcegen.model.ClassDef;
import io.micronaut.sourcegen.model.ClassTypeDef;
import io.micronaut.sourcegen.model.EnumDef;
import io.micronaut.sourcegen.model.ExpressionDef;
import io.micronaut.sourcegen.model.ObjectDef;
import io.micronaut.sourcegen.model.RecordDef;
import io.micronaut.sourcegen.model.TypeDef;
import org.objectweb.asm.commons.GeneratorAdapter;

Expand Down Expand Up @@ -61,34 +64,71 @@ private static void cast(GeneratorAdapter generatorAdapter, MethodContext contex
if (!from.isPrimitive() && to.isPrimitive()) {
unbox(generatorAdapter, context, to);
}
} else if (!from.makeNullable().equals(to.makeNullable())) {
if (from instanceof ClassTypeDef.ClassElementType fromElement) {
ClassElement fromClassElement = fromElement.classElement();
if (to instanceof ClassTypeDef.ClassElementType toElement) {
if (!fromClassElement.isAssignable(toElement.classElement())) {
checkCast(generatorAdapter, context, from, to);
}
} else if (to instanceof ClassTypeDef.JavaClass toClass) {
if (!fromClassElement.isAssignable(toClass.type())) {
checkCast(generatorAdapter, context, from, to);
}
} else if (to instanceof ClassTypeDef.ClassName toClassName) {
if (!fromClassElement.isAssignable(toClassName.className())) {
checkCast(generatorAdapter, context, from, to);
}
} else {
checkCast(generatorAdapter, context, from, to);
}
} else if (from instanceof ClassTypeDef.JavaClass fromClass && to instanceof ClassTypeDef.JavaClass toClass) {
if (!toClass.type().isAssignableFrom(fromClass.type())) {
checkCast(generatorAdapter, context, from, to);
}
} else {
checkCast(generatorAdapter, context, from, to);
} else if (needsCast(from, to)) {
checkCast(generatorAdapter, context, from, to);
}
}

private static boolean needsCast(TypeDef from, TypeDef to) {
if (from.makeNullable().equals(to.makeNullable())) {
return false;
}
if (from instanceof ClassTypeDef.Parameterized parameterized) {
return needsCast(parameterized.rawType(), to);
}
if (to instanceof ClassTypeDef.Parameterized parameterized) {
return needsCast(from, parameterized.rawType());
}
if (from instanceof ClassTypeDef.ClassElementType fromElement) {
return needsCast(fromElement.classElement(), to);
}
if (from instanceof ClassTypeDef.JavaClass fromClass) {
if (to instanceof ClassTypeDef.JavaClass toClass) {
return !toClass.type().isAssignableFrom(fromClass.type());
}
}
if (from instanceof ClassTypeDef.ClassDefType fromClassDef) {
ClassTypeDef fromSuperclass = getSuperclass(fromClassDef.objectDef());
if (fromSuperclass != null) {
return needsCast(fromSuperclass, to);
}
}
return true;
}

private static boolean needsCast(ClassElement from, TypeDef to) {
if (to instanceof ClassTypeDef.ClassElementType toElement) {
return !from.isAssignable(toElement.classElement());
}
if (to instanceof ClassTypeDef.JavaClass toClass) {
return !from.isAssignable(toClass.type());
}
if (to instanceof ClassTypeDef.ClassName toClassName) {
return !from.isAssignable(toClassName.name());
}
if (to instanceof ClassTypeDef.ClassDefType toClassDefType) {
if (from.isAssignable(toClassDefType.getName())) {
return false;
}
return !from.isAssignable(toClassDefType.getName());
}
return true;
}

private static ClassTypeDef getSuperclass(ObjectDef objectDef) {
if (objectDef instanceof ClassDef classDef) {
return classDef.getSuperclass();
}
if (objectDef instanceof EnumDef) {
return ClassTypeDef.of(Enum.class);
}
if (objectDef instanceof RecordDef) {
return ClassTypeDef.of(Record.class);
}
return null;
}


private static void checkCast(GeneratorAdapter generatorAdapter, MethodContext context, TypeDef from, TypeDef to) {
TypeDef toType = ObjectDef.getContextualType(context.objectDef(), to);
if (!toType.makeNullable().equals(from.makeNullable())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.AbstractList;
import java.util.List;
import java.util.Map;

import static io.micronaut.sourcegen.bytecode.DecompilerUtils.decompileToJava;
Expand Down Expand Up @@ -1519,6 +1521,169 @@ Object invoke() {
""", decompileToJava(bytes));
}

@Test
void testCastingClassDefWithSuperclass() {

ClassDef myList = ClassDef.builder("example.MyList")
.superclass(ClassTypeDef.of(AbstractList.class))
.build();

ClassDef classDef = ClassDef.builder("example.Test")
.addMethod(MethodDef.builder("load")
.returns(ClassTypeDef.of(List.class))
.build((aThis, methodParameters) -> myList.asTypeDef()
.instantiate().returning())
)
.build();

StringWriter bytecodeWriter = new StringWriter();
byte[] bytes = generateFile(classDef, bytecodeWriter);

String bytecode = bytecodeWriter.toString();

Assertions.assertEquals("""
// class version 61.0 (61)
// access flags 0x0
// signature Ljava/lang/Object;
// declaration: example/Test
class example/Test {
// access flags 0x0
<init>()V
ALOAD 0
INVOKESPECIAL java/lang/Object.<init> ()V
RETURN
// access flags 0x0
load()Ljava/util/List;
NEW example/MyList
DUP
INVOKESPECIAL example/MyList.<init> ()V
ARETURN
}
""", bytecode);

Assertions.assertEquals("""
package example;
import java.util.List;
class Test {
List load() {
return new MyList();
}
}
""", decompileToJava(bytes));
}

@Test
void testCastingThisClassDefWithSuperclass() {

ClassDef classDef = ClassDef.builder("example.Test")
.superclass(TypeDef.parameterized(AbstractList.class, Number.class))
.addMethod(MethodDef.builder("load")
.returns(ClassTypeDef.of(List.class))
.build((aThis, methodParameters) -> aThis.type().instantiate().returning())
)
.build();

StringWriter bytecodeWriter = new StringWriter();
byte[] bytes = generateFile(classDef, bytecodeWriter);

String bytecode = bytecodeWriter.toString();

Assertions.assertEquals("""
// class version 61.0 (61)
// access flags 0x0
// signature Ljava/util/AbstractList<Ljava/lang/Number;>;
// declaration: example/Test extends java.util.AbstractList<java.lang.Number>
class example/Test extends java/util/AbstractList {
// access flags 0x0
<init>()V
ALOAD 0
INVOKESPECIAL java/util/AbstractList.<init> ()V
RETURN
// access flags 0x0
load()Ljava/util/List;
NEW example/Test
DUP
INVOKESPECIAL example/Test.<init> ()V
ARETURN
}
""", bytecode);

Assertions.assertEquals("""
package example;
import java.util.AbstractList;
import java.util.List;
class Test extends AbstractList {
List load() {
return new Test();
}
}
""", decompileToJava(bytes));
}

@Test
void testCastingEnum() {

EnumDef enumDef = EnumDef.builder("example.MyEnum")
.addEnumConstant("A")
.addEnumConstant("B")
.build();

ClassDef classDef = ClassDef.builder("example.Test")
.addMethod(MethodDef.builder("load")
.returns(Enum.class)
.build((aThis, methodParameters) -> enumDef.asTypeDef()
.getStaticField(enumDef.getField("A"))
.returning())
)
.build();

StringWriter bytecodeWriter = new StringWriter();
byte[] bytes = generateFile(classDef, bytecodeWriter);

String bytecode = bytecodeWriter.toString();

Assertions.assertEquals("""
// class version 61.0 (61)
// access flags 0x0
// signature Ljava/lang/Object;
// declaration: example/Test
class example/Test {
// access flags 0x0
<init>()V
ALOAD 0
INVOKESPECIAL java/lang/Object.<init> ()V
RETURN
// access flags 0x0
load()Ljava/lang/Enum;
GETSTATIC example/MyEnum.A : Lexample/MyEnum;
ARETURN
}
""", bytecode);

Assertions.assertEquals("""
package example;
class Test {
Enum load() {
return MyEnum.A;
}
}
""", decompileToJava(bytes));
}

private String toBytecode(ObjectDef objectDef) {
StringWriter stringWriter = new StringWriter();
generateFile(objectDef, stringWriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public final class ClassDef extends ObjectDef {
private final ClassTypeDef superclass;
private final StatementDef staticInitializer;

private ClassDef(ClassTypeDef type,
private ClassDef(ClassTypeDef.ClassName className,
EnumSet<Modifier> modifiers,
List<FieldDef> fields,
List<MethodDef> methods,
Expand All @@ -52,16 +52,17 @@ private ClassDef(ClassTypeDef type,
ClassTypeDef superclass,
List<ObjectDef> innerTypes,
StatementDef staticInitializer) {
super(type, modifiers, annotations, javadoc, methods, properties, superinterfaces, innerTypes);
super(className, modifiers, annotations, javadoc, methods, properties, superinterfaces, innerTypes);
ClassTypeDef.of(this);
this.fields = fields;
this.typeVariables = typeVariables;
this.superclass = superclass;
this.staticInitializer = staticInitializer;
}

@Override
public ClassDef withType(ClassTypeDef type) {
return new ClassDef(type, modifiers, fields, methods, properties, annotations, javadoc, typeVariables, superinterfaces, superclass, innerTypes, staticInitializer);
public ClassDef withClassName(ClassTypeDef.ClassName className) {
return new ClassDef(className, modifiers, fields, methods, properties, annotations, javadoc, typeVariables, superinterfaces, superclass, innerTypes, staticInitializer);
}

@Override
Expand Down Expand Up @@ -108,7 +109,7 @@ public FieldDef findField(String name) {
public FieldDef getField(String name) {
FieldDef field = findField(name);
if (field == null) {
throw new IllegalStateException("Class: " + this.name + " doesn't have a field: " + name);
throw new IllegalStateException("Class: " + this.className + " doesn't have a field: " + name);
}
return null;
}
Expand All @@ -122,8 +123,11 @@ public boolean hasField(String name) {
if (superclass instanceof ClassTypeDef.ClassElementType classElementType) {
return classElementType.classElement().findField(name).isPresent();
}
if (superclass instanceof ClassTypeDef.ClassDefType classDefType) {
return classDefType.classDef().hasField(name);
if (superclass instanceof ClassTypeDef.ClassDefType classDefType && classDefType.objectDef() instanceof ClassDef classDef) {
return classDef.hasField(name);
}
if (superclass instanceof ClassTypeDef.ClassDefType classDefType && classDefType.objectDef() instanceof EnumDef enumDef) {
return enumDef.hasField(name);
}
if (superclass instanceof ClassTypeDef.JavaClass javaClass) {
try {
Expand All @@ -146,7 +150,7 @@ public StatementDef getStaticInitializer() {

@Override
public String toString() {
return "ClassDef{" + "name='" + name + '\'' + '}';
return "ClassDef{" + "name='" + className + '\'' + '}';
}

/**
Expand Down Expand Up @@ -199,7 +203,7 @@ public ClassDefBuilder addStaticInitializer(StatementDef staticInitializer) {
}

public ClassDef build() {
return new ClassDef(ClassTypeDef.of(name), modifiers, fields, methods, properties, annotations, javadoc, typeVariables, superinterfaces, superclass, innerTypes, staticInitializer);
return new ClassDef(new ClassTypeDef.ClassName(name), modifiers, fields, methods, properties, annotations, javadoc, typeVariables, superinterfaces, superclass, innerTypes, staticInitializer);
}

/**
Expand Down
Loading

0 comments on commit 355e935

Please sign in to comment.