Skip to content

Commit

Permalink
Fix a minor issue with leadingBlockDimensions selection.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Aug 15, 2024
1 parent 2a0f77d commit 4352e65
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion bin/nnc/laplacian_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ int main(int argc, char** argv)
for (int j = 0; j < sizeof(transposeStates) / (sizeof(bool) * 2); j++)
{
TestDescriptor testDescriptor = TestDescriptor();
testDescriptor.precision = GEMMOperandPrecision::FP16;
testDescriptor.precision = GEMMOperandPrecision::FP32;
testDescriptor.problemSize = problemSize;
testDescriptor.transposeState[0] = transposeStates[j * 2];
testDescriptor.transposeState[1] = transposeStates[j * 2 + 1];
Expand Down
6 changes: 3 additions & 3 deletions lib/nnc/mfa/v2/GEMMDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
constants->setConstantValue(&N, MTL::DataTypeUInt, 1);
constants->setConstantValue(&K, MTL::DataTypeUInt, 2);

auto chooseLeadingDimension =
[=](unsigned int specifiedLeading, bool transposeState, unsigned int untransposedRows, unsigned int untransposedColumns) -> unsigned int {
auto chooseLeadingDimension =
[=](unsigned int specifiedLeading, bool transposeState, unsigned int untransposedRows, unsigned int untransposedColumns) -> unsigned int {
unsigned int expectedLeading;
if (transposeState) {
expectedLeading = untransposedRows;
Expand Down Expand Up @@ -191,7 +191,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
// the cache takes ownership, and the pointer doesn't become a zombie
// object.
PipelineValue<GEMMKernel>* output = new PipelineValue<GEMMKernel> { kernel, pipeline };
return std::make_pair(kernelDesc, output);
return std::make_pair(kernelDesc, output);
} else {
auto kernelDesc = GEMMKernelDescriptor(blockDimensionsAndPaddedBlockDimensions.first, this->memoryPrecisions, blockDimensionsAndPaddedBlockDimensions.second, preferAsyncLoad, preferAsyncStore.value(), registerPrecisions, splits, this->transposeState, this->useBias);
struct Candidate {
Expand Down
12 changes: 10 additions & 2 deletions lib/nnc/mfa/v2/GEMMKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ GEMMKernel::GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const devic
transposeState = descriptor.transposeState;
preferAsyncLoad = descriptor.preferAsyncLoad;
preferAsyncStore = descriptor.preferAsyncStore;
auto useBias = descriptor.useBias;
useBias = descriptor.useBias;
threadgroupSize = 32 * splits[0] * splits[1];

// Validate the correctness of register precisions.
Expand Down Expand Up @@ -248,6 +248,7 @@ std::string GEMMKernel::createSource() const noexcept {
source.SetValue("MEMORY_NAME_A", memoryName('A'));
source.SetValue("MEMORY_NAME_B", memoryName('B'));
source.SetValue("MEMORY_NAME_C", memoryName('C'));
source.SetValue("MEMORY_NAME_BIAS", memoryName('S'));
source.SetValue("REGISTER_NAME_A", registerName('A'));
source.SetValue("REGISTER_NAME_B", registerName('B'));
source.SetValue("REGISTER_NAME_C", registerName('C'));
Expand Down Expand Up @@ -283,6 +284,13 @@ std::string GEMMKernel::createSource() const noexcept {
kernel void gemm(device {{MEMORY_NAME_A}} *A [[buffer(0)]],
device {{MEMORY_NAME_B}} *B [[buffer(1)]],
device {{MEMORY_NAME_C}} *C [[buffer(2)]],
)";
if (useBias) {
source += R"(
device {{MEMORY_NAME_BIAS}} *bias [[buffer(3)]],
)";
}
source += R"(
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
uint3 gid [[threadgroup_position_in_grid]],
Expand Down Expand Up @@ -521,7 +529,7 @@ void GEMMKernel::createLoadC(CodeWriter *source) const noexcept {
// In the vanilla GEMM kernel, the extra storing code can be optimized
// away at compile time. The compiler may allocate less registers, and
// occupancy may be greater.
std::string output = "(M >= M_group) && (N >= N_group)";
std::string output = "(M >= M_group) && (N >= N_group)";

// When accumulate is supported, there are overlapping writes. We must
// sanitize the matrix edge with async copy. The optimization from
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/mfa/v2/GEMMKernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct GEMMKernel {

bool preferAsyncStore;

bool useBias;

uint16_t registerM;

uint16_t registerN;
Expand Down
8 changes: 4 additions & 4 deletions lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ std::pair<simd::ushort3, std::optional<simd::ushort3>> GEMMKernelDescriptor::get
// - (memA, memB, memC) = (FP16, FP16, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP16)
if (transposeState[0] == false && transposeState[1] == false) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 24, 48 });
} else if (transposeState[0] == false && transposeState[1] == true) {
if (!transposeState[0] && !transposeState[1]) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 48, 48 });
} else if (!transposeState[0] && transposeState[1]) {
if (memoryPrecisions.B == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 28, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 24, 48 });
}
} else if (transposeState[0] == true && transposeState[1] == false) {
} else if (transposeState[0] && !transposeState[1]) {
if (memoryPrecisions.A == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 52, 48, 48 });
} else {
Expand Down

0 comments on commit 4352e65

Please sign in to comment.