Skip to content

Commit

Permalink
Rework on the CStyle type check
Browse files Browse the repository at this point in the history
Instead of checking what is C-Style struct, we should check C-Style
type.

Proposal is not very clear about this. Will also update proposal.
  • Loading branch information
kaizhangNV committed Jan 23, 2025
1 parent 01b65c8 commit d508623
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 136 deletions.
176 changes: 60 additions & 116 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,142 +216,71 @@ ConstructorDecl* SemanticsVisitor::_getSynthesizedConstructor(
return nullptr;
}

static StructDecl* _getStructDecl(Type* type)
bool SemanticsVisitor::isCStyleType(Type* type)
{
if (as<VectorExpressionType>(type) || as<MatrixExpressionType>(type) ||
as<ArithmeticExpressionType>(type) || as<BuiltinType>(type))
auto cacheResult = [&](bool result)
{
return nullptr;
}

if (auto structDecl = isDeclRefTypeOf<StructDecl>(type))
return structDecl.getDecl();

return nullptr;
}

bool SemanticsVisitor::_cStyleStructBasicCheck(Decl* decl)
{
// 1. It has to be a user-defined struct type, or a basic scalar, vector or matrix type
StructDecl* structDecl = nullptr;
if (isFromCoreModule(decl))
return false;
getShared()->cacheCStyleType(type, result);
return result;
};

if (auto varDecl = as<VarDecl>(decl))
// Check cache first
if (bool* isCStyle = getShared()->isCStyleType(type))
{
auto type = varDecl->getType();
if (as<VectorExpressionType>(type) || as<MatrixExpressionType>(type) ||
as<BasicExpressionType>(type) || isDeclRefTypeOf<EnumDecl>(type).getDecl())
return true;

// check for user-defined struct type
structDecl = _getStructDecl(type);
if (!structDecl)
return false;
return *isCStyle;
}

if (!structDecl)
structDecl = as<StructDecl>(decl);
// 1. It has to be basic scalar, vector or matrix type, or user-defined struct.
if (as<VectorExpressionType>(type) || as<MatrixExpressionType>(type) ||
as<BasicExpressionType>(type) || isDeclRefTypeOf<EnumDecl>(type).getDecl())
return cacheResult(true);


// 2. It cannot have inheritance, but inherit from interface is fine.
for (auto inheritanceDecl : structDecl->getMembersOfType<InheritanceDecl>())
if (auto structDecl = isDeclRefTypeOf<StructDecl>(type).getDecl())
{
if (!isDeclRefTypeOf<InterfaceDecl>(inheritanceDecl->base.type))
// 2. It cannot have inheritance, but inherit from interface is fine.
for (auto inheritanceDecl : structDecl->getMembersOfType<InheritanceDecl>())
{
return false;
if (!isDeclRefTypeOf<InterfaceDecl>(inheritanceDecl->base.type))
{
return cacheResult(false);
}
}
}

// 3. It cannot have explicit constructor
if (_hasExplicitConstructor(structDecl, true))
return false;
// 3. It cannot have explicit constructor
if (_hasExplicitConstructor(structDecl, true))
return cacheResult(false);

// 4. All of its members have to have the same visibility as the struct itself.
DeclVisibility structVisibility = getDeclVisibility(structDecl);
for (auto varDecl : structDecl->getMembersOfType<VarDeclBase>())
{
if (getDeclVisibility(varDecl) != structVisibility)
// 4. All of its members have to have the same visibility as the struct itself.
DeclVisibility structVisibility = getDeclVisibility(structDecl);
for (auto varDecl : structDecl->getMembersOfType<VarDeclBase>())
{
return false;
if (getDeclVisibility(varDecl) != structVisibility)
{
return cacheResult(false);
}
}
}
return true;
}

bool SemanticsVisitor::isCStyleStruct(StructDecl* structDecl)
{
// Get the result from the cache first
if (bool* isCStyle = getShared()->isCStyleStruct(structDecl))
{
return *isCStyle;
}
for (auto varDecl : structDecl->getMembersOfType<VarDeclBase>())
{
Type* varType = varDecl->getType();

// rules 1-4 are checked in _cStyleStructBasicCheck for all the non-array members
if (!_cStyleStructBasicCheck(structDecl))
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
// Recursively check the type of the member.
if (!isCStyleType(varType))
return cacheResult(false);
}
}

// 5. All its members are legacy C-Style structs or arrays of legacy C-style structs
for (auto varDecl : structDecl->getMembersOfType<VarDeclBase>())
if (auto arrayType = as<ArrayExpressionType>(type))
{
// if the member is an array, check if the element is legacy C-style rule.
if (auto arrayType = as<ArrayExpressionType>(varDecl->getType()))
if (arrayType->isUnsized())
{
if (arrayType->isUnsized())
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
auto* elementType = arrayType->getElementType();
for (;;)
{
if (auto nextType = as<ArrayExpressionType>(elementType))
{
if (arrayType->isUnsized())
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
elementType = nextType->getElementType();
}
else
break;
}

if (auto elemStructDecl = _getStructDecl(elementType))
{
if (!_cStyleStructBasicCheck(elemStructDecl))
{
getShared()->cacheCStyleStruct(elemStructDecl, false);
return false;
}
}
else
{
// if the element is not a struct, it has to be a scalar, vector or matrix.
if (!as<VectorExpressionType>(elementType) &&
!as<MatrixExpressionType>(elementType) && !as<BasicExpressionType>(elementType))
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
}
}
else
{
// all the other members still go through the basic check.
if (!_cStyleStructBasicCheck(varDecl))
{
getShared()->cacheCStyleStruct(structDecl, false);
return false;
}
getShared()->cacheCStyleType(type, false);
return cacheResult(false);
}
}

getShared()->cacheCStyleStruct(structDecl, true);
return true;
return cacheResult(true);
}

Expr* SemanticsVisitor::_createCtorInvokeExpr(
Expand Down Expand Up @@ -385,7 +314,7 @@ bool SemanticsVisitor::createInvokeExprForExplicitCtor(
// back to legacy initializer list logic.
if (fromInitializerListExpr->m_synthesizedForTypeCastZero)
{
if (!isCStyleStruct(toStructDeclRef.getDecl()))
if (!isCStyleType(toType))
return false;
}

Expand All @@ -396,7 +325,22 @@ bool SemanticsVisitor::createInvokeExprForExplicitCtor(
fromInitializerListExpr->loc,
fromInitializerListExpr->args);

ctorInvokeExpr = CheckTerm(ctorInvokeExpr);
DiagnosticSink tempSink(getSourceManager(), nullptr);
SemanticsVisitor subVisitor(withSink(&tempSink));
ctorInvokeExpr = subVisitor.CheckTerm(ctorInvokeExpr);

if (tempSink.getErrorCount())
{
if (!isCStyleType(toType))
{
Slang::ComPtr<ISlangBlob> blob;
tempSink.getBlobIfNeeded(blob.writeRef());
getSink()->diagnoseRaw(
Severity::Error,
static_cast<char const*>(blob->getBufferPointer()));
}
return false;
}

if (outExpr)
{
Expand All @@ -418,7 +362,7 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor(
if (!structDecl || structDecl->m_synthesizedCtorMap.getCount() == 0)
return false;

bool isCStyle = isCStyleStruct(structDecl);
bool isCStyle = isCStyleType(toType);

// TODO: This is just a special case for a backwards-compatibility feature
// for HLSL, this flag will imply that the initializer list is synthesized
Expand Down
14 changes: 5 additions & 9 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -760,14 +760,11 @@ struct SharedSemanticsContext : public RefObject
m_mapTypePairToImplicitCastMethod[key] = candidate;
}

bool* isCStyleStruct(StructDecl* structDecl)
{
return m_isCStyleStructCache.tryGetValue(structDecl);
}
bool* isCStyleType(Type* type) { return m_isCStyleTypeCache.tryGetValue(type); }

void cacheCStyleStruct(StructDecl* structDecl, bool isCStyle)
void cacheCStyleType(Type* type, bool isCStyle)
{
m_isCStyleStructCache.addIfNotExists(structDecl, isCStyle);
m_isCStyleTypeCache.addIfNotExists(type, isCStyle);
}
// Get the inner most generic decl that a decl-ref is dependent on.
// For example, `Foo<T>` depends on the generic decl that defines `T`.
Expand Down Expand Up @@ -898,7 +895,7 @@ struct SharedSemanticsContext : public RefObject
Dictionary<DeclRef<Decl>, InheritanceInfo> m_mapDeclRefToInheritanceInfo;
Dictionary<TypePair, SubtypeWitness*> m_mapTypePairToSubtypeWitness;
Dictionary<ImplicitCastMethodKey, ImplicitCastMethod> m_mapTypePairToImplicitCastMethod;
Dictionary<StructDecl*, bool> m_isCStyleStructCache;
Dictionary<Type*, bool> m_isCStyleTypeCache;
};

/// Local/scoped state of the semantic-checking system
Expand Down Expand Up @@ -2814,8 +2811,7 @@ struct SemanticsVisitor : public SemanticsContext
ConstructorDecl* _getSynthesizedConstructor(
StructDecl* structDecl,
ConstructorDecl::ConstructorFlavor flavor);
bool isCStyleStruct(StructDecl* structDecl);
bool _cStyleStructBasicCheck(Decl* decl);
bool isCStyleType(Type* type);
};


Expand Down
4 changes: 2 additions & 2 deletions tests/cross-compile/cpp-resource.slang
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void computeMain(
uint tid = dispatchThreadID.x;

int2 fromScalar = tid.x;
uint2 another = {0}; // uint2 is not C-Style struct
uint2 another = {};

float2 loc = dispatchThreadID.xy * 0.5f;

Expand All @@ -50,4 +50,4 @@ void computeMain(
doSomething(int(tid), arr);

outputBuffer[tid] = int(tid * tid) + thing.a + thing3.a + int(v + s) + value + fromScalar.y + int(another.y) + int(m.x) + int(l) + int(arr[0].y); // + thing.a;
}
}
4 changes: 2 additions & 2 deletions tests/hlsl-intrinsic/matrix-double-reduced-intrinsic.slang
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

Float scalarF = idx * (1.0f / (4.0f));

FloatMatrix ft = {float2(0.0f), float2(0.0f)}; // matrix is not C-Style struct
FloatMatrix ft = {};

FloatMatrix f = { { scalarF + 0.01, scalarF + 0.02}, { scalarF + 0.011, scalarF + 0.022}};

Expand Down Expand Up @@ -77,4 +77,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
ft += clamp(f, makeFloatMatrix(0.1l), makeFloatMatrix(0.3l));

outputBuffer[idx] = calcTotal(ft);
}
}
2 changes: 1 addition & 1 deletion tests/hlsl-intrinsic/matrix-double.slang
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

Float scalarF = idx * (1.0f / (4.0f));

FloatMatrix ft = {float2(0.0f), float2(0.0f)}; // matrix is not C-Style struct
FloatMatrix ft = {};

FloatMatrix f = { { scalarF + 0.01, scalarF + 0.02}, { scalarF + 0.011, scalarF + 0.022}};

Expand Down
3 changes: 1 addition & 2 deletions tests/hlsl-intrinsic/matrix-float.slang
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)

float scalarF = idx * (1.0f / (4.0f));

// Note matrix<> is not C-Style struct, so we can't use {} to initialize it.
FloatMatrix ft = {float2(0.0f), float2(0.0f)};
FloatMatrix ft = {};

FloatMatrix f = { { scalarF + 0.01, scalarF + 0.02f}, { scalarF + 0.011f, scalarF + 0.022f}};

Expand Down
2 changes: 1 addition & 1 deletion tests/hlsl-intrinsic/matrix-int.slang
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
matrix<int, 3, 3> a = { { 0, 1, 2}, {2, 4, 6}, {16, 21, 32}};
matrix<int, 3, 3> b = { { 4, 9, -1}, {-2, 4, 2}, {31, -3, 7}};

matrix<int, 3, 3> t = {int3(0), int3(0), int3(0)}; // matrix is not C-Style struct
matrix<int, 3, 3> t = {};

t += max(a, b);
t += min(a, b);
Expand Down
4 changes: 2 additions & 2 deletions tests/hlsl-intrinsic/vector-double-reduced-intrinsic.slang
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
FloatVector f = FloatVector(0.1f, vf, vf + 0.2f);

// Operate over all values
FloatVector ft = {};
FloatVector ft = {0.0f};

// TODO(JS):
// fmod - disabled not available on D3D for double
Expand Down Expand Up @@ -71,4 +71,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
ft += clamp(f, 0.1l, 0.3l);

outputBuffer[idx] = vector<Float, 4>(ft, 0);
}
}
2 changes: 1 addition & 1 deletion tests/hlsl-intrinsic/vector-float.slang
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
FloatVector f = FloatVector(0.1f, vf, vf + 0.2f);

// Operate over all values
FloatVector ft = {0.0f}; // vector is not C-Style struct
FloatVector ft = {};

// fmod
ft += FloatVector(IntVector(((f % 0.11f) * 100) + 0.5));
Expand Down

0 comments on commit d508623

Please sign in to comment.