diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
index 8b17840d3..aa1d580bf 100644
--- a/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
+++ b/lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
@@ -150,6 +150,7 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
 		if (use_mfa) {
 			mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
 			ccv_nnc_mfa_add_params_t params = {
+				.args = 2,
 				.data_type = mtl_data_type,
 				.length = (uint32_t)length,
 			};
diff --git a/lib/nnc/cmd/ew/mps/ccv_nnc_ew_mps.m b/lib/nnc/cmd/ew/mps/ccv_nnc_ew_mps.m
index a93493ef1..95038303f 100644
--- a/lib/nnc/cmd/ew/mps/ccv_nnc_ew_mps.m
+++ b/lib/nnc/cmd/ew/mps/ccv_nnc_ew_mps.m
@@ -19,46 +19,149 @@ static int _ccv_nnc_ewsum_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hin
 	}
 	ccv_nnc_tensor_view_t* const c = (ccv_nnc_tensor_view_t*)outputs[0];
 	@autoreleasepool {
-		MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
-		ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, k, hint, flags, inputs, input_size, outputs, output_size);
-		int* indices = (int*)ccv_nnc_stream_context_get_workspace(stream_context, (sizeof(int) + sizeof(MPSGraphTensorData*)) * input_size, CCV_TENSOR_CPU_MEMORY);
-		MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray* inputTensors, NSMutableArray* inputShapedTypes, NSMutableArray* resultTensors) {
+		bool use_mfa = true;
+		const char *fallback_reason = NULL;
+		ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
+
+		if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
+			use_mfa = false;
+			fallback_reason = "Disabled.";
+		}
+
+		if (use_mfa) {
+			for (z = 0; z < input_size; z++)
+			{
+				if (inputs[z]->info.datatype != CCV_16F && inputs[z]->info.datatype != CCV_32F) {
+					use_mfa = false;
+					fallback_reason = "Unsupported data type.";
+					break;
+				}
+			}
+			if (outputs[0]->info.datatype != CCV_16F && outputs[0]->info.datatype != CCV_32F) {
+				use_mfa = false;
+				fallback_reason = "Unsupported data type.";
+			}
+		}
+		if (use_mfa) {
+			int datatype = outputs[0]->info.datatype;
+			for (z = 0; z < input_size; z++)
+			{
+				if (inputs[z]->info.datatype != datatype) {
+					use_mfa = false;
+					fallback_reason = "Mismatched data type.";
+					break;
+				}
+			}
+		}
+		const size_t length = ccv_nnc_tensor_count(outputs[0]->info);
+		if (use_mfa) {
+			for (z = 0; z < input_size; z++)
+			{
+				if (ccv_nnc_tensor_count(inputs[z]->info) != length) {
+					use_mfa = false;
+					fallback_reason = "Broadcast semantics unsupported.";
+					break;
+				}
+			}
+		}
+		if (use_mfa) {
+			for (z = 0; z < input_size; z++)
+			{
+				if (!CCV_IS_TENSOR_CONTIGUOUS(inputs[z])) {
+					use_mfa = false;
+					fallback_reason = "Strided.";
+					break;
+				}
+			}
+			if (!CCV_IS_TENSOR_CONTIGUOUS(outputs[0])) {
+				use_mfa = false;
+				fallback_reason = "Strided.";
+			}
+		}
+		uint32_t mtl_data_type = UINT32_MAX;
+		if (use_mfa) {
+			switch (outputs[0]->info.datatype) {
+				case CCV_16F: {
+					mtl_data_type = 16;
+					break;
+				}
+				case CCV_32F: {
+					mtl_data_type = 3;
+					break;
+				}
+				default: {
+					use_mfa = false;
+					fallback_reason = "Unsupported data type.";
+					break;
+				}
+			}
+		}
+		if (input_size > 256) {
+			use_mfa = false;
+			fallback_reason = "Too many buffers.";
+		}
+		if (use_mfa) {
+			mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
+			ccv_nnc_mfa_add_params_t params = {
+				.args = input_size,
+				.data_type = mtl_data_type,
+				.length = (uint32_t)length,
+			};
+			ccv_nnc_mfa_prepare_add(context, params);
+
+			mtl_buffer_t* tensors[input_size + 2];
+			for (z = 0; z < input_size; z++)
+				tensors[z] = mpgetbuffer(inputs[z]);
+			tensors[input_size] = mpgetbuffer(outputs[0]);
+			tensors[input_size + 1] = NULL;
+			size_t tensor_offsets[input_size + 1];
+			for (z = 0; z < input_size; z++)
+				tensor_offsets[z] = inputs[z]->dataof;
+			tensor_offsets[input_size] = outputs[0]->dataof;
+			ccv_nnc_mfa_encode_add(context, params, command_batch, tensors, tensor_offsets);
+			ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
+		} else {
+			MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
+			ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, k, hint, flags, inputs, input_size, outputs, output_size);
+			int* indices = (int*)ccv_nnc_stream_context_get_workspace(stream_context, (sizeof(int) + sizeof(MPSGraphTensorData*)) * input_size, CCV_TENSOR_CPU_MEMORY);
+			MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray* inputTensors, NSMutableArray* inputShapedTypes, NSMutableArray* resultTensors) {
+				ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
+				MPSGraphTensor* mps_input_a;
+				MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
+				[inputTensors addObject:mps_input_a];
+				MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
+				[inputShapedTypes addObject:mps_a_shape];
+				MPSGraphTensor* mps_c = mps_a;
+				int z;
+				for (z = 0; z < input_size - 1; z++)
+				{
+					const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
+					MPSGraphTensor* mps_input_b;
+					MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
+					[inputTensors addObject:mps_input_b];
+					MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
+					[inputShapedTypes addObject:mps_b_shape];
+					mps_c = [graph additionWithPrimaryTensor:mps_c secondaryTensor:mps_b name:nil];
+				}
+				[resultTensors addObject:mps_c];
+			});
+			MPSGraphTensorData** data = (MPSGraphTensorData**)(indices + input_size);
 			ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
-			MPSGraphTensor* mps_input_a;
-			MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
-			[inputTensors addObject:mps_input_a];
-			MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
-			[inputShapedTypes addObject:mps_a_shape];
-			MPSGraphTensor* mps_c = mps_a;
-			int z;
+			MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
+			data[0] = data_a;
 			for (z = 0; z < input_size - 1; z++)
 			{
 				const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
-				MPSGraphTensor* mps_input_b;
-				MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
-				[inputTensors addObject:mps_input_b];
-				MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
-				[inputShapedTypes addObject:mps_b_shape];
-				mps_c = [graph additionWithPrimaryTensor:mps_c secondaryTensor:mps_b name:nil];
+				MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
+				data[z + 1] = data_b;
 			}
-			[resultTensors addObject:mps_c];
-		});
-		MPSGraphTensorData** data = (MPSGraphTensorData**)(indices + input_size);
-		ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
-		MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
-		data[0] = data_a;
-		for (z = 0; z < input_size - 1; z++)
-		{
-			const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
-			MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
-			data[z + 1] = data_b;
+			NSMutableArray* feeds = [NSMutableArray new];
+			for (z = 0; z < input_size; z++)
+				[feeds addObject:data[indices[z]]];
+			ccv_nnc_mps_graph_executable_result(executable, command_buffer, feeds, &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1, 0);
+			[feeds release];
+			ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
 		}
-		NSMutableArray* feeds = [NSMutableArray new];
-		for (z = 0; z < input_size; z++)
-			[feeds addObject:data[indices[z]]];
-		ccv_nnc_mps_graph_executable_result(executable, command_buffer, feeds, &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1, 0);
-		[feeds release];
-		ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
 	}
 	return CCV_NNC_EXEC_SUCCESS;
 }
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp
index b12eca206..a177158f3 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_add.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_add.cpp
@@ -23,9 +23,10 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para
     encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors));
     num_tensors += 1;
   }
-  CCV_NNC_MFA_PRECONDITION(num_tensors == 3);
+  CCV_NNC_MFA_PRECONDITION(num_tensors == 1 + params.args);
 
   AddDescriptor descriptor;
+  descriptor.args = params.args;
   descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
 
   descriptor.length = params.length;
@@ -47,16 +48,18 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para
 
   encoder->setComputePipelineState(pipeline.get());
 
-  if (tensors[0] == tensors[2]) {
-    encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
-    encoder->useResource(tensors[1], MTL::ResourceUsageRead);
-  } else if (tensors[1] == tensors[2]) {
-    encoder->useResource(tensors[0], MTL::ResourceUsageRead);
-    encoder->useResource(tensors[1], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
-  } else {
-    encoder->useResource(tensors[0], MTL::ResourceUsageRead);
-    encoder->useResource(tensors[1], MTL::ResourceUsageRead);
-    encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
+  int i;
+  int flag = 0;
+  for (i = 0; i < params.args; i++) {
+    if (tensors[i] == tensors[params.args]) {
+      encoder->useResource(tensors[i], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
+      flag = 1;
+    } else {
+      encoder->useResource(tensors[i], MTL::ResourceUsageRead);
+    }
+  }
+  if (!flag) {
+    encoder->useResource(tensors[params.args], MTL::ResourceUsageWrite);
   }
 
   unsigned int count;
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_add.hpp b/lib/nnc/mfa/ccv_nnc_mfa_add.hpp
index 69a54f1cc..52ec6633d 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_add.hpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_add.hpp
@@ -3,6 +3,7 @@
 
 typedef struct {
   uint64_t data_type;
+  uint8_t args;
   uint32_t length;
 } ccv_nnc_mfa_add_params_t;
 
diff --git a/lib/nnc/mfa/v2/AddDescriptor.cpp b/lib/nnc/mfa/v2/AddDescriptor.cpp
index 1ffac4305..7750283d3 100644
--- a/lib/nnc/mfa/v2/AddDescriptor.cpp
+++ b/lib/nnc/mfa/v2/AddDescriptor.cpp
@@ -6,6 +6,7 @@
 bool AddDescriptor::operator==(const AddDescriptor& rhs) const {
   return
   memoryPrecision == rhs.memoryPrecision &&
+  args == rhs.args &&
   value == rhs.value &&
   length == rhs.length;
 }
@@ -14,7 +15,7 @@ std::size_t std::hash<AddDescriptor>::operator()(const AddDescriptor& hash) cons
   using namespace ccv::nnc::mfa::hash;
   std::size_t seed = 0;
   combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value }));
-  combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, 0 }));
+  combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, (unsigned int)hash.args }));
   return seed;
 }
@@ -35,6 +36,7 @@ std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKe
   };
 
   AddKernelDescriptor kernelDesc;
+  kernelDesc.args = args;
   kernelDesc.value = value;
   kernelDesc.memoryPrecision = memoryPrecision;
 
diff --git a/lib/nnc/mfa/v2/AddDescriptor.hpp b/lib/nnc/mfa/v2/AddDescriptor.hpp
index ebd1a2126..cf4bfd418 100644
--- a/lib/nnc/mfa/v2/AddDescriptor.hpp
+++ b/lib/nnc/mfa/v2/AddDescriptor.hpp
@@ -8,9 +8,10 @@
 #include "GEMMOperandPrecision.hpp"
 
 struct AddKernelDescriptor {
+  uint8_t args;
   uint8_t value;
   GEMMOperandPrecision memoryPrecision;
-  constexpr bool operator==(const AddKernelDescriptor &rhs) const { return value == rhs.value && memoryPrecision == rhs.memoryPrecision; }
+  constexpr bool operator==(const AddKernelDescriptor &rhs) const { return args == rhs.args && value == rhs.value && memoryPrecision == rhs.memoryPrecision; }
 };
 
 template<>
@@ -22,6 +23,8 @@ struct std::hash<AddKernelDescriptor>
 struct AddKernel;
 
 struct AddDescriptor {
+  uint8_t args;
+
   uint8_t value;
 
   GEMMOperandPrecision memoryPrecision;
diff --git a/lib/nnc/mfa/v2/AddKernel.cpp b/lib/nnc/mfa/v2/AddKernel.cpp
index d3516e8b2..63a2ccaa2 100644
--- a/lib/nnc/mfa/v2/AddKernel.cpp
+++ b/lib/nnc/mfa/v2/AddKernel.cpp
@@ -1,10 +1,13 @@
 #include "AddKernel.hpp"
"../ccv_nnc_mfa.hpp" +#include "CodeWriter.hpp" #include AddKernel::AddKernel(AddKernelDescriptor descriptor, MTL::Device *const device) { + args = descriptor.args; + value = descriptor.value; memoryPrecision = descriptor.memoryPrecision; @@ -29,61 +32,79 @@ unsigned short AddKernel::createThreadgroupMemoryAllocation() const noexcept { } std::string AddKernel::createSource() const noexcept { - std::string shader = createConstants() + "\n"; + CodeWriter source; + source += createConstants() + "\n"; + std::string buffers = ""; + if (value == 0 || value == 1) { + for (int i = 1; i < args; i++) { + buffers += "device real4 *src" + std::to_string(i) + " [[buffer(" + std::to_string(i) + ")]],\n"; + } + } else { + for (int i = 1; i < args; i++) { + buffers += "device real *src" + std::to_string(i) + " [[buffer(" + std::to_string(i) + ")]],\n"; + } + } + source.SetValue("OTHER_SOURCE_BUFFERS", buffers); + source.SetValue("DESTINATION_INDEX", std::to_string(args)); + std::string items = " + src1[idx]"; + for (int i = 2; i < args; i++) { + items += " + src" + std::to_string(i) + "[idx]"; + } + source.SetValue("OTHER_SOURCE_ITEMS", items); if (value == 0) { - shader += R"( + source += R"( #include using namespace metal; kernel void add( device real4 *src0 [[buffer(0)]], - device real4 *src1 [[buffer(1)]], - device real4 *destination [[buffer(2)]], + {{OTHER_SOURCE_BUFFERS}} + device real4 *destination [[buffer({{DESTINATION_INDEX}})]], uint3 tpig [[thread_position_in_grid]] ) { const uint idx = tpig.x; - destination[idx] = src0[idx] + src1[idx]; + destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}}; } )"; } else if (value == 1) { - shader += R"( + source += R"( #include using namespace metal; kernel void add( device real4 *src0 [[buffer(0)]], - device real4 *src1 [[buffer(1)]], - device real4 *destination [[buffer(2)]], + {{OTHER_SOURCE_BUFFERS}} + device real4 *destination [[buffer({{DESTINATION_INDEX}})]], uint3 tpig [[thread_position_in_grid]] ) { const uint idx = tpig.x; if (idx >= count) return; - destination[idx] = src0[idx] + src1[idx]; + destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}}; } )"; } else { - shader += R"( + source += R"( #include using namespace metal; kernel void add( device real *src0 [[buffer(0)]], - device real *src1 [[buffer(1)]], - device real *destination [[buffer(2)]], + {{OTHER_SOURCE_BUFFERS}} + device real *destination [[buffer({{DESTINATION_INDEX}})]], uint3 tpig [[thread_position_in_grid]] ) { const uint idx = tpig.x; if (idx >= count) return; - destination[idx] = src0[idx] + src1[idx]; + destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}}; } )"; } - return shader; + return source.ToString(); } std::string AddKernel::createConstants() const noexcept { diff --git a/lib/nnc/mfa/v2/AddKernel.hpp b/lib/nnc/mfa/v2/AddKernel.hpp index 93c15cae6..2d5849c5d 100644 --- a/lib/nnc/mfa/v2/AddKernel.hpp +++ b/lib/nnc/mfa/v2/AddKernel.hpp @@ -15,6 +15,8 @@ struct AddKernel { /// The number of threads per group. MTL::Size threadgroupSize; + uint8_t args; + uint8_t value; GEMMOperandPrecision memoryPrecision;