From 785cf57a163cd2eba5b9a4e198fda40cae2787df Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Thu, 15 Aug 2024 19:01:55 -0400
Subject: [PATCH] Use new GEMM throughout.

---
 lib/nnc/mfa/ccv_nnc_mfa.cpp       |   4 +
 lib/nnc/mfa/ccv_nnc_mfa.hpp       |   1 +
 lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp  | 363 ++++++++++--------------------
 lib/nnc/mfa/v2/GEMMDescriptor.cpp |  24 +-
 lib/nnc/mfa/v2/GEMMDescriptor.hpp |   2 +
 lib/nnc/mfa/v2/GEMMKernel.cpp     |  20 ++
 lib/nnc/mfa/v2/ShaderCache.hpp    |   4 +
 lib/nnc/mps/ccv_nnc_mps.m         |   1 +
 8 files changed, 175 insertions(+), 244 deletions(-)

diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp
index d357156e1..b50937e25 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -10,6 +10,10 @@ mfa::context* ccv_nnc_init_mfa_context(MTL::Device* device) {
   return new mfa::context(device);
 }
 
+void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context) {
+  context->v2_cache.evict();
+}
+
 void ccv_nnc_deinit_mfa_context(mfa::context* context) {
   delete context;
 }
diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp
index 9708b7fcf..e1d12fef6 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.hpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp
@@ -71,6 +71,7 @@ extern "C" {
 #endif // __cplusplus
 
 ccv_nnc_mfa_context_t* ccv_nnc_init_mfa_context(mtl_device_t* context);
+void ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_mfa_context_t* context);
 void ccv_nnc_deinit_mfa_context(ccv_nnc_mfa_context_t* context);
 uint8_t ccv_nnc_mfa_context_supported(ccv_nnc_mfa_context_t* context);
 uint16_t ccv_nnc_mfa_context_log_level(ccv_nnc_mfa_context_t* context);
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
index b14bf3398..465d09bf8 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
@@ -23,259 +23,142 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
     num_tensors += 1;
   }
   CCV_NNC_MFA_PRECONDITION((num_tensors == 3) || (num_tensors == 4))
-  
-  // Count the number of GEMMs at all.
-  //
-  // MFA       | 39 |  60% |
-  // MPSMatrix |  5 |   8% |
-  // MPSGraph  | 21 |  32% |
-  // Total     | 65 | 100% |
-  ccv_nnc_mfa_log_message("\n");
-  ccv_nnc_mfa_log_message("MFA\n");
-  
-  // Count the percentage of GEMMs that match a specific set of criteria.
-  // - data_type = any
-  // - M = any
-  // - N = any
-  // - K = any
-  // - A_trans = any
-  // - B_trans = any
-  // - D_trans = any
-  // - alpha = 1.0
-  // - beta = 0.0
-  // - batched = false
-  // - fused_activation_function = false
-  // - fused_bias = false
-  //
-  // - batch_dims_a = any
-  // - batch_dims_b = any
-  // - batch_dims_d = any
-  //
-  // - num_tensors = 3
-  //
-  // YES    | 17 |  44% | <- works without any modification
-  // NO (1) |  0 |   0% | <- suspected unused features: alpha, beta, activation
-  // NO (2) | 17 |  13% | <- batching
-  // NO (3) |  5 |  44% | <- fused bias (D operand)
-  // Total  | 39 | 100% |
-  //
-  // Supporting the remaining variants should require little effort.
-  // NO (1) - Scan the codebase to ensure these features are never used. Then,
-  //          delete the corresponding members from 'mfa::hash'.
-  // NO (2) - Batching just requires changing the memory address of A/B/C.
-  // NO (3) - For fused bias, take note of the shift to the matrix origin (to
-  //          keep memory reads in-bounds without using async copies). Make
-  //          sure you apply this shift to the bias vector.
-  bool canEncodeNewGEMM = false;
-  if ((params.alpha == 1.0) &&
-      (params.beta == 0.0) &&
-      (params.batched == false) &&
-      (params.fused_activation_function == false)) {
-    canEncodeNewGEMM = true;
-    ccv_nnc_mfa_log_message("\n");
-    ccv_nnc_mfa_log_message("YES\n");
-  } else {
-    if ((params.alpha != 1.0) ||
-        (params.beta != 0.0) ||
-        (params.fused_activation_function != false)) {
-      ccv_nnc_mfa_log_message("\n");
-      ccv_nnc_mfa_log_message("NO (1)\n");
-    } else if (params.batched != false) {
-      ccv_nnc_mfa_log_message("\n");
-      ccv_nnc_mfa_log_message("NO (2)\n");
-    } else {
-      ccv_nnc_mfa_log_message("\n");
-      ccv_nnc_mfa_log_message("NO (3)\n");
-    }
-  }
-  
+  // Branch on whether to use the new kernel.
-  if (canEncodeNewGEMM) {
-    // Instantiate the descriptor.
-    GEMMDescriptor gemmDesc;
-    gemmDesc.matrixDimensions = simd::uint3 {
-      params.M,
-      params.N,
-      params.K,
-    };
-    switch (params.data_type) {
-    case MTL::DataTypeHalf: {
-      gemmDesc.memoryPrecisions = {
-        .A = GEMMOperandPrecision::FP16,
-        .B = GEMMOperandPrecision::FP16,
-        .C = GEMMOperandPrecision::FP16,
-        .bias = GEMMOperandPrecision::FP16,
-      };
-      break;
-    }
-    case MTL::DataTypeFloat: {
-      gemmDesc.memoryPrecisions = {
-        .A = GEMMOperandPrecision::FP32,
-        .B = GEMMOperandPrecision::FP32,
-        .C = GEMMOperandPrecision::FP32,
-        .bias = GEMMOperandPrecision::FP32,
-      };
-      break;
-    }
-    default:
-      CCV_NNC_MFA_PRECONDITION(false);
-      break;
-    }
-    gemmDesc.transposeState = simd::uchar3 { params.A_trans, params.B_trans, params.A_trans };
-    gemmDesc.useBias = params.fused_bias;
-    
-    // Instantiate the kernel.
-    //
-    // TODO: Remove the autoreleasepool, once you confirm the caller always
-    // makes one. Or find a different solution, like spawning a pool inside
-    // of 'fetchKernel' when a new kernel variant is compiled.
-    auto pool = NS::AutoreleasePool::alloc()->init();
-    auto &shaderCache = context->v2_cache;
-    DeviceProperties dprops = DeviceProperties();
-    dprops.coreCount = 18;
-    auto pipelineValue = shaderCache.findKernel(gemmDesc, context->device.get(), dprops);
-    pool->drain();
-    auto kernel = pipelineValue->kernel;
-    auto pipeline = pipelineValue->pipeline;
-    
-    // Allocate a new command.
-    auto encoder = command_batch->startCommand();
-    encoder->setComputePipelineState(pipeline.get());
-    encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);
-    
-    // Bind the function arguments.
-    encoder->useResource(tensors[0], MTL::ResourceUsageRead);
-    encoder->useResource(tensors[1], MTL::ResourceUsageRead);
-    encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
-    if (num_tensors >= 4) {
-      encoder->useResource(tensors[3], MTL::ResourceUsageRead);
-    }
-    for (int i = 0; i < num_tensors; ++i) {
-      encoder->setBuffer(tensors[i], tensor_offsets[i], i);
-    }
-    
-    // Calculate the grid size.
-    auto ceilDivide =
-    [=](int64_t target, uint16_t granularity) -> int64_t {
-      return (target + int64_t(granularity) - 1) / int64_t(granularity);
-    };
-    MTL::Size gridSize
-    (ceilDivide(int64_t(params.N), kernel->blockDimensions[1]),
-     ceilDivide(int64_t(params.M), kernel->blockDimensions[0]),
-     1);
-    MTL::Size groupSize
-    (int64_t(kernel->threadgroupSize), 1, 1);
-    
-    // Dispatch the required number of threads.
-    encoder->dispatchThreadgroups(gridSize, groupSize);
-    
-    // Finish the command.
- command_batch->finishCommand(encoder); - } else { - mfa::gemm::hash hash(params); - auto iterator = context->gemm_cache.map.find(hash); - if (iterator == context->gemm_cache.map.end()) { - mfa::precondition_failure("GEMM hash not cached.", __LINE__, __FILE__, __FUNCTION__); - } - - auto* pipeline = iterator->second; - auto encoder = command_batch->startCommand(); - encoder->setComputePipelineState(pipeline->pso.get()); - encoder->setThreadgroupMemoryLength(pipeline->threadgroup_memory_length, 0); - - encoder->useResource(tensors[0], MTL::ResourceUsageRead); - encoder->useResource(tensors[1], MTL::ResourceUsageRead); - encoder->useResource(tensors[2], MTL::ResourceUsageWrite); - if (num_tensors >= 4) { - encoder->useResource(tensors[3], MTL::ResourceUsageRead); + GEMMDescriptor gemmDesc; + gemmDesc.matrixDimensions = simd::uint3 { + params.M, + params.N, + params.K, + }; + switch (params.data_type) { + case MTL::DataTypeHalf: { + gemmDesc.memoryPrecisions = { + .A = GEMMOperandPrecision::FP16, + .B = GEMMOperandPrecision::FP16, + .C = GEMMOperandPrecision::FP16, + .bias = GEMMOperandPrecision::FP16, + }; + break; } - for (int i = 0; i < num_tensors; ++i) { - encoder->setBuffer(tensors[i], tensor_offsets[i], i); + case MTL::DataTypeFloat: { + gemmDesc.memoryPrecisions = { + .A = GEMMOperandPrecision::FP32, + .B = GEMMOperandPrecision::FP32, + .C = GEMMOperandPrecision::FP32, + .bias = GEMMOperandPrecision::FP32, + }; + break; } - - // Simple broadcasting rules; not yet support for NumPy broadcasting rules. + default: + CCV_NNC_MFA_PRECONDITION(false); + break; + } + gemmDesc.transposeState = simd::uchar3 { params.A_trans, params.B_trans, params.D_trans }; + gemmDesc.loadPreviousC = false; + gemmDesc.useBias = params.fused_bias; + if (params.batched) { simd::ushort4 num_batch_dims(0); simd::ulong4 batch_sizes(1); - if (params.batched) { - for (uint16_t operand = 0; operand < 4; ++operand) { - uint32_t* batch_dims; - if (operand == 0) { - batch_dims = params.batch_dims_a; - } else if (operand == 1) { - batch_dims = params.batch_dims_b; - } else if (operand == 2) { - // Skip the C operand. + for (uint16_t operand = 0; operand < 4; ++operand) { + uint32_t* batch_dims; + if (operand == 0) { + batch_dims = params.batch_dims_a; + } else if (operand == 1) { + batch_dims = params.batch_dims_b; + } else if (operand == 2) { + // Skip the C operand. + continue; + } else if (operand == 3) { + // Skip the D operand if unavailable. + if (!(params.fused_activation_function || params.fused_bias)) { continue; - } else if (operand == 3) { - // Skip the D operand if unavailable. - if (!(params.fused_activation_function || params.fused_bias)) { - continue; - } - batch_dims = params.batch_dims_d; - } - - for (int i = 0; i < CCV_NNC_MAX_DIM_ALLOC; ++i) { - if (batch_dims[i] == 0) { - break; - } - num_batch_dims[operand] += 1; - batch_sizes[operand] *= batch_dims[i]; } + batch_dims = params.batch_dims_d; } - - uint16_t data_type_size = 0; - switch (params.data_type) { - case MTL::DataTypeHalf: { - data_type_size = 2; - break; - } - case MTL::DataTypeFloat: { - data_type_size = 4; + + for (int i = 0; i < CCV_NNC_MAX_DIM_ALLOC; ++i) { + if (batch_dims[i] == 0) { break; } - default: - CCV_NNC_MFA_PRECONDITION(false); - break; - } - uint64_t byte_stride_a = hash.M * hash.K * data_type_size; - uint64_t byte_stride_b = hash.K * hash.N * data_type_size; - uint64_t byte_stride_c = hash.M * hash.N * data_type_size; - uint64_t byte_stride_d = (hash.D_trans ? 
hash.M : hash.N) * data_type_size; - if (batch_sizes[0] == 1) { - byte_stride_a = 0; - } - if (batch_sizes[1] == 1) { - byte_stride_b = 0; - } - if (batch_sizes[3] == 1) { - byte_stride_d = 0; - } - - const unsigned long batch_size = std::max(batch_sizes[0], batch_sizes[1]); - simd::ulong4 matrix_offsets[batch_size]; - for (int i = 0; i < batch_size; ++i) { - matrix_offsets[i] = simd::ulong4 { - i * byte_stride_a, - i * byte_stride_b, - i * byte_stride_c, - i * byte_stride_d, - }; - } - if (batch_size * 32 > 4096) { - auto buffer = context->device->newBuffer(matrix_offsets, batch_size * 32, MTL::ResourceStorageModeShared); - encoder->useResource(buffer, MTL::ResourceUsageRead); - encoder->setBuffer(buffer, 0, 10); - buffer->release(); - } else { - encoder->setBytes(matrix_offsets, batch_size * 32, 10); + num_batch_dims[operand] += 1; + batch_sizes[operand] *= batch_dims[i]; } } - - auto grid_size = pipeline->grid_size; - grid_size.depth = batch_sizes[0]; - encoder->dispatchThreadgroups(grid_size, pipeline->group_size); - command_batch->finishCommand(encoder); + + uint32_t stride_a = params.M * params.K; + uint32_t stride_b = params.K * params.N; + uint32_t stride_c = params.M * params.N; + uint32_t stride_d = params.D_trans ? params.M : params.N; + if (batch_sizes[0] == 1) { + stride_a = 0; + } + if (batch_sizes[1] == 1) { + stride_b = 0; + } + if (batch_sizes[3] == 1) { + stride_d = 0; + } + + const unsigned long batch_size = std::max(batch_sizes[0], batch_sizes[1]); + gemmDesc.batchDimension = batch_size; + simd::uint4 batchStrides; + batchStrides[0] = stride_a; + batchStrides[1] = stride_b; + batchStrides[2] = stride_c; + batchStrides[3] = stride_d; + gemmDesc.batchStrides = batchStrides; + } else { + gemmDesc.batchDimension = 1; + gemmDesc.batchStrides = std::nullopt; } + + // Instantiate the kernel. + // + // TODO: Remove the autoreleasepool, once you confirm the caller always + // makes one. Or find a different solution, like spawning a pool inside + // of 'fetchKernel' when a new kernel variant is compiled. + auto pool = NS::AutoreleasePool::alloc()->init(); + auto &shaderCache = context->v2_cache; + DeviceProperties dprops = DeviceProperties(); + auto pipelineValue = shaderCache.findKernel(gemmDesc, context->device.get(), dprops); + pool->drain(); + auto kernel = pipelineValue->kernel; + auto pipeline = pipelineValue->pipeline; + + // Allocate a new command. + auto encoder = command_batch->startCommand(); + encoder->setComputePipelineState(pipeline.get()); + encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0); + + // Bind the function arguments. + encoder->useResource(tensors[0], MTL::ResourceUsageRead); + encoder->useResource(tensors[1], MTL::ResourceUsageRead); + encoder->useResource(tensors[2], MTL::ResourceUsageWrite); + if (num_tensors >= 4) { + encoder->useResource(tensors[3], MTL::ResourceUsageRead); + } + for (int i = 0; i < num_tensors; ++i) { + encoder->setBuffer(tensors[i], tensor_offsets[i], i); + } + + // Calculate the grid size. + auto ceilDivide = + [=](int64_t target, uint16_t granularity) -> int64_t { + return (target + int64_t(granularity) - 1) / int64_t(granularity); + }; + MTL::Size gridSize + (ceilDivide(int64_t(params.N), kernel->blockDimensions[1]), + ceilDivide(int64_t(params.M), kernel->blockDimensions[0]), + gemmDesc.batchDimension); + MTL::Size groupSize + (int64_t(kernel->threadgroupSize), 1, 1); + + // Dispatch the required number of threads. + encoder->dispatchThreadgroups(gridSize, groupSize); + + // Finish the command. 
+  command_batch->finishCommand(encoder);
 }
 
 // MARK: - C++
diff --git a/lib/nnc/mfa/v2/GEMMDescriptor.cpp b/lib/nnc/mfa/v2/GEMMDescriptor.cpp
index 503817504..298c8a841 100644
--- a/lib/nnc/mfa/v2/GEMMDescriptor.cpp
+++ b/lib/nnc/mfa/v2/GEMMDescriptor.cpp
@@ -9,6 +9,7 @@ bool GEMMDescriptor::operator==(const GEMMDescriptor& rhs) const {
   (batchDimension == rhs.batchDimension) &&
   simd_all(matrixDimensions == rhs.matrixDimensions) &&
   simd_all(leadingDimensions.value_or(simd::uint3(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint3(UINT32_MAX))) &&
+  simd_all(batchStrides.value_or(simd::uint4(UINT32_MAX)) == rhs.batchStrides.value_or(simd::uint4(UINT32_MAX))) &&
   memoryPrecisions == rhs.memoryPrecisions &&
   registerPrecisionC == rhs.registerPrecisionC &&
   simd_all(transposeState == rhs.transposeState) &&
@@ -27,6 +28,12 @@ std::size_t std::hash<GEMMDescriptor>::operator()(const GEMMDescriptor& hash) co
     combine_32(seed, hash.leadingDimensions.value()[1]);
     combine_32(seed, hash.leadingDimensions.value()[2]);
   }
+  if (hash.batchStrides.has_value()) {
+    combine_32(seed, hash.batchStrides.value()[0]);
+    combine_32(seed, hash.batchStrides.value()[1]);
+    combine_32(seed, hash.batchStrides.value()[2]);
+    combine_32(seed, hash.batchStrides.value()[3]);
+  }
   combine_64(seed, pack_64(simd::ushort4 { hash.memoryPrecisions.A.value, hash.memoryPrecisions.B.value, hash.memoryPrecisions.C.value, hash.memoryPrecisions.bias.value }));
   combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], 0 }));
   combine_32(seed, pack_32(simd::uchar4 { hash.loadPreviousC, hash.useBias, 0, 0 }));
@@ -44,10 +51,8 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
   [=](GEMMKernelDescriptor descriptor) -> GEMMKernel* {
     auto iterator = libraryCache->find(descriptor);
     if (iterator != libraryCache->end()) {
-      std::cout << "Library cache hit." << std::endl;
       return iterator->second.get();
     } else {
-      std::cout << "Library cache miss." << std::endl;
       GEMMKernel* kernel = new GEMMKernel(descriptor, device);
       (*libraryCache)[descriptor] = std::unique_ptr<GEMMKernel>(kernel);
       return kernel;
@@ -57,7 +62,6 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
   // WARNING: The owner must explicitly retain the compute pipeline.
   auto createPipeline =
   [=](MTL::Library* library) -> MTL::ComputePipelineState* {
-    std::cout << "Pipeline cache miss." << std::endl;
     // Set the function constants.
     auto constants = NS::TransferPtr
     (MTL::FunctionConstantValues::alloc()->init());
@@ -106,7 +110,19 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
 
     bool loadPreviousC = this->loadPreviousC;
     constants->setConstantValue(&loadPreviousC, MTL::DataTypeBool, 10);
-    
+
+    bool batched = this->batchDimension > 1;
+    constants->setConstantValue(&batched, MTL::DataTypeBool, 11);
+    simd::uint4 batchStrides = this->batchStrides.value_or(simd::uint4(0));
+    auto batchStrideA = batchStrides[0];
+    auto batchStrideB = batchStrides[1];
+    auto batchStrideC = batchStrides[2];
+    auto batchStrideBias = batchStrides[3];
+    constants->setConstantValue(&batchStrideA, MTL::DataTypeUInt, 15);
+    constants->setConstantValue(&batchStrideB, MTL::DataTypeUInt, 16);
+    constants->setConstantValue(&batchStrideC, MTL::DataTypeUInt, 17);
+    constants->setConstantValue(&batchStrideBias, MTL::DataTypeUInt, 18);
+
     NS::String* swiftName = NS::String::string("gemm", NS::UTF8StringEncoding);
     NS::Error* error = nil;
 
diff --git a/lib/nnc/mfa/v2/GEMMDescriptor.hpp b/lib/nnc/mfa/v2/GEMMDescriptor.hpp
index b44df05d0..f0a5134c9 100644
--- a/lib/nnc/mfa/v2/GEMMDescriptor.hpp
+++ b/lib/nnc/mfa/v2/GEMMDescriptor.hpp
@@ -33,6 +33,8 @@ struct GEMMDescriptor {
   std::optional<GEMMOperandPrecision> registerPrecisionC;
 
   std::optional<simd::uint3> leadingDimensions;
+
+  std::optional<simd::uint4> batchStrides;
 
   simd::uchar3 transposeState;
 
diff --git a/lib/nnc/mfa/v2/GEMMKernel.cpp b/lib/nnc/mfa/v2/GEMMKernel.cpp
index 96cf87a33..8baf0931a 100644
--- a/lib/nnc/mfa/v2/GEMMKernel.cpp
+++ b/lib/nnc/mfa/v2/GEMMKernel.cpp
@@ -305,6 +305,18 @@ source += R"(
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]])
 {
+  if (batched) {
+    A = A + A_batch_stride * gid.z;
+    B = B + B_batch_stride * gid.z;
+    C = C + C_batch_stride * gid.z;
+)";
+  if (useBias) {
+    source += R"(
+    bias = bias + bias_batch_stride * gid.z;
+)";
+  }
+source += R"(
+  }
   ushort2 sid(sidx % {{SPLITS_N}}, sidx / {{SPLITS_N}});
   ushort2 morton_offset = morton_order(lane_id);
 
@@ -386,6 +398,14 @@ constant uint C_leading_dimension [[function_constant(7)]];
 
 // Whether to load the previous value of C, and add it to the accumulator.
 constant bool load_previous_C [[function_constant(10)]];
+// Specify the batch / batch strides at PSO creation time.
+constant bool batched [[function_constant(11)]];
+
+constant uint A_batch_stride [[function_constant(15)]];
+constant uint B_batch_stride [[function_constant(16)]];
+constant uint C_batch_stride [[function_constant(17)]];
+constant uint bias_batch_stride [[function_constant(18)]];
+
 // Whether each matrix is transposed.
 constant bool A_trans = {{TRANSPOSE_STATE_A}};
 constant bool B_trans = {{TRANSPOSE_STATE_B}};
diff --git a/lib/nnc/mfa/v2/ShaderCache.hpp b/lib/nnc/mfa/v2/ShaderCache.hpp
index e10f932bf..ae312ce04 100644
--- a/lib/nnc/mfa/v2/ShaderCache.hpp
+++ b/lib/nnc/mfa/v2/ShaderCache.hpp
@@ -70,6 +70,10 @@ struct ShaderCache {
     pipelineCache->map[descriptor] = std::unique_ptr<PipelineValue<T>>(result.second);
     return result.second;
   }
+
+  void evict() noexcept {
+    pipelineCache.clear();
+  }
 };
 
 #endif
diff --git a/lib/nnc/mps/ccv_nnc_mps.m b/lib/nnc/mps/ccv_nnc_mps.m
index 1f7f2a9c7..de817b65f 100644
--- a/lib/nnc/mps/ccv_nnc_mps.m
+++ b/lib/nnc/mps/ccv_nnc_mps.m
@@ -427,6 +427,7 @@ static inline void ccv_nnc_mps_graph_key_free(ccv_nnc_mps_graph_key_t key)
 
 void ccv_nnc_mps_clear_graph_executable_cache(void)
 {
+	ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_default_mfa_context());
 	if (!g_graph_executable_cache)
 		return;
 	khiter_t k;
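
Note (not part of the patch): the sketch below is a minimal, hypothetical illustration of how the new batched path is driven, using only the interfaces introduced above (GEMMDescriptor::batchDimension, batchStrides, and the batch-stride function constants 11 and 15-18 consumed in GEMMKernel.cpp). The concrete dimensions, the FP16 precisions, and the helper name makeBatchedGEMMDescriptor are made up for the example; strides are in elements, and a stride of 0 reuses the same matrix for every batch element, mirroring the batch_sizes[i] == 1 special case in ccv_nnc_mfa_encode_gemm.

// Hypothetical usage sketch, assuming the headers from this patch.
#include "GEMMDescriptor.hpp"   // lib/nnc/mfa/v2/
#include <cstdint>

static GEMMDescriptor makeBatchedGEMMDescriptor(
    uint32_t M, uint32_t N, uint32_t K, uint32_t batchCount, bool D_trans) {
  GEMMDescriptor gemmDesc;
  gemmDesc.matrixDimensions = simd::uint3 { M, N, K };
  gemmDesc.memoryPrecisions = {
    .A = GEMMOperandPrecision::FP16,
    .B = GEMMOperandPrecision::FP16,
    .C = GEMMOperandPrecision::FP16,
    .bias = GEMMOperandPrecision::FP16,
  };
  gemmDesc.transposeState = simd::uchar3 { false, false, D_trans };
  gemmDesc.loadPreviousC = false;
  gemmDesc.useBias = true;
  // Element strides between consecutive matrices in the batch. These become
  // the A/B/C/bias_batch_stride function constants (indices 15-18), which the
  // kernel applies per threadgroup as `operand += stride * gid.z`.
  gemmDesc.batchDimension = batchCount;
  gemmDesc.batchStrides = simd::uint4 {
    M * K,              // A advances one matrix per batch element.
    K * N,              // B advances one matrix per batch element.
    M * N,              // C advances one matrix per batch element.
    D_trans ? M : N,    // bias advances one vector per batch element.
  };
  return gemmDesc;
}

On the encoder side nothing else changes relative to an unbatched dispatch except the grid depth: the new code passes gemmDesc.batchDimension as the third dimension of gridSize, so each threadgroup selects its batch element from gid.z.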
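
Note (not part of the patch): a small sketch of the new cache-clearing hook, using only names that appear above. ccv_nnc_mfa_clear_pipeline_cache forwards to ShaderCache::evict() on the context's v2_cache, and ccv_nnc_mps_clear_graph_executable_cache now calls it before it drops the MPSGraph executable cache, so both caches are emptied together.

// Hypothetical caller-side view; ccv_nnc_default_mfa_context() already exists
// in lib/nnc, and the C API declaration comes from ccv_nnc_mfa.hpp above.
void drop_cached_gemm_pipelines(void)
{
	// Evicts the compiled pipelines held by the v2 ShaderCache; the next
	// ccv_nnc_mfa_encode_gemm call rebuilds and re-caches what it needs.
	ccv_nnc_mfa_clear_pipeline_cache(ccv_nnc_default_mfa_context());
}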