Skip to content

Commit

Permalink
Update ewsum to use add shader.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 11, 2024
1 parent bb0cfeb commit 7832cde
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 61 deletions.
1 change: 1 addition & 0 deletions lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
if (use_mfa) {
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
ccv_nnc_mfa_add_params_t params = {
.args = 2,
.data_type = mtl_data_type,
.length = (uint32_t)length,
};
Expand Down
171 changes: 137 additions & 34 deletions lib/nnc/cmd/ew/mps/ccv_nnc_ew_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,149 @@ static int _ccv_nnc_ewsum_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hin
}
ccv_nnc_tensor_view_t* const c = (ccv_nnc_tensor_view_t*)outputs[0];
@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, k, hint, flags, inputs, input_size, outputs, output_size);
int* indices = (int*)ccv_nnc_stream_context_get_workspace(stream_context, (sizeof(int) + sizeof(MPSGraphTensorData*)) * input_size, CCV_TENSOR_CPU_MEMORY);
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
bool use_mfa = true;
const char *fallback_reason = NULL;
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();

if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
use_mfa = false;
fallback_reason = "Disabled.";
}

if (use_mfa) {
for (z = 0; z < input_size; z++)
{
if (inputs[z]->info.datatype != CCV_16F && inputs[z]->info.datatype != CCV_32F) {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
}
if (outputs[0]->info.datatype != CCV_16F && outputs[0]->info.datatype != CCV_32F) {
use_mfa = false;
fallback_reason = "Unsupported data type.";
}
}
if (use_mfa) {
int datatype = outputs[0]->info.datatype;
for (z = 0; z < input_size; z++)
{
if (inputs[z]->info.datatype != datatype) {
use_mfa = false;
fallback_reason = "Mismatched data type.";
break;
}
}
}
const size_t length = ccv_nnc_tensor_count(outputs[0]->info);
if (use_mfa) {
for (z = 0; z < input_size; z++)
{
if (ccv_nnc_tensor_count(inputs[z]->info) != length) {
use_mfa = false;
fallback_reason = "Broadcast semantics unsupported.";
break;
}
}
}
if (use_mfa) {
for (z = 0; z < input_size; z++)
{
if (!CCV_IS_TENSOR_CONTIGUOUS(inputs[z])) {
use_mfa = false;
fallback_reason = "Strided.";
break;
}
}
if (!CCV_IS_TENSOR_CONTIGUOUS(outputs[0])) {
use_mfa = false;
fallback_reason = "Strided.";
}
}
uint32_t mtl_data_type = UINT32_MAX;
if (use_mfa) {
switch (outputs[0]->info.datatype) {
case CCV_16F: {
mtl_data_type = 16;
break;
}
case CCV_32F: {
mtl_data_type = 3;
break;
}
default: {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
}
}
if (input_size > 256) {
use_mfa = false;
fallback_reason = "Too many buffers.";
}
if (use_mfa) {
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
ccv_nnc_mfa_add_params_t params = {
.args = input_size,
.data_type = mtl_data_type,
.length = (uint32_t)length,
};
ccv_nnc_mfa_prepare_add(context, params);

mtl_buffer_t* tensors[input_size + 2];
for (z = 0; z < input_size; z++)
tensors[z] = mpgetbuffer(inputs[z]);
tensors[input_size] = mpgetbuffer(outputs[0]);
tensors[input_size + 1] = NULL;
size_t tensor_offsets[input_size + 1];
for (z = 0; z < input_size; z++)
tensor_offsets[z] = inputs[z]->dataof;
tensor_offsets[input_size] = outputs[0]->dataof;
ccv_nnc_mfa_encode_add(context, params, command_batch, tensors, tensor_offsets);
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
} else {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, k, hint, flags, inputs, input_size, outputs, output_size);
int* indices = (int*)ccv_nnc_stream_context_get_workspace(stream_context, (sizeof(int) + sizeof(MPSGraphTensorData*)) * input_size, CCV_TENSOR_CPU_MEMORY);
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_c = mps_a;
int z;
for (z = 0; z < input_size - 1; z++)
{
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
mps_c = [graph additionWithPrimaryTensor:mps_c secondaryTensor:mps_b name:nil];
}
[resultTensors addObject:mps_c];
});
MPSGraphTensorData** data = (MPSGraphTensorData**)(indices + input_size);
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_c = mps_a;
int z;
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
data[0] = data_a;
for (z = 0; z < input_size - 1; z++)
{
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
mps_c = [graph additionWithPrimaryTensor:mps_c secondaryTensor:mps_b name:nil];
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
data[z + 1] = data_b;
}
[resultTensors addObject:mps_c];
});
MPSGraphTensorData** data = (MPSGraphTensorData**)(indices + input_size);
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[k];
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
data[0] = data_a;
for (z = 0; z < input_size - 1; z++)
{
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)(z >= k ? inputs[z + 1] : inputs[z]);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
data[z + 1] = data_b;
NSMutableArray<MPSGraphTensorData*>* feeds = [NSMutableArray new];
for (z = 0; z < input_size; z++)
[feeds addObject:data[indices[z]]];
ccv_nnc_mps_graph_executable_result(executable, command_buffer, feeds, &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1, 0);
[feeds release];
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
NSMutableArray<MPSGraphTensorData*>* feeds = [NSMutableArray new];
for (z = 0; z < input_size; z++)
[feeds addObject:data[indices[z]]];
ccv_nnc_mps_graph_executable_result(executable, command_buffer, feeds, &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1, 0);
[feeds release];
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
return CCV_NNC_EXEC_SUCCESS;
}
Expand Down
25 changes: 14 additions & 11 deletions lib/nnc/mfa/ccv_nnc_mfa_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para
encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors));
num_tensors += 1;
}
CCV_NNC_MFA_PRECONDITION(num_tensors == 3);
CCV_NNC_MFA_PRECONDITION(num_tensors == 1 + params.args);

AddDescriptor descriptor;
descriptor.args = params.args;
descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
descriptor.length = params.length;

Expand All @@ -47,16 +48,18 @@ void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_para

encoder->setComputePipelineState(pipeline.get());

if (tensors[0] == tensors[2]) {
encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
encoder->useResource(tensors[1], MTL::ResourceUsageRead);
} else if (tensors[1] == tensors[2]) {
encoder->useResource(tensors[0], MTL::ResourceUsageRead);
encoder->useResource(tensors[1], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
} else {
encoder->useResource(tensors[0], MTL::ResourceUsageRead);
encoder->useResource(tensors[1], MTL::ResourceUsageRead);
encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
int i;
int flag = 0;
for (i = 0; i < params.args; i++) {
if (tensors[i] == tensors[params.args]) {
encoder->useResource(tensors[i], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
flag = 1;
} else {
encoder->useResource(tensors[i], MTL::ResourceUsageRead);
}
}
if (!flag) {
encoder->useResource(tensors[params.args], MTL::ResourceUsageWrite);
}

unsigned int count;
Expand Down
1 change: 1 addition & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

// Parameters for one MFA elementwise-add kernel dispatch.
typedef struct {
// Raw MTL::DataType value shared by all operands (callers set 16 for FP16, 3 for FP32).
uint64_t data_type;
// Number of source buffers the kernel sums. NOTE(review): uint8_t caps this at 255,
// but the visible ewsum caller only falls back when input_size > 256 — an
// input_size of exactly 256 would wrap to 0 here; confirm the guard.
uint8_t args;
// Element count of each operand; all sources and the destination are expected
// to be contiguous and of this same length (no broadcasting).
uint32_t length;
} ccv_nnc_mfa_add_params_t;

Expand Down
4 changes: 3 additions & 1 deletion lib/nnc/mfa/v2/AddDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// Descriptors compare equal only when every field that influences kernel
// generation matches: operand precision, source count, variant, and length.
bool AddDescriptor::operator==(const AddDescriptor& rhs) const {
  const bool samePrecision = (memoryPrecision == rhs.memoryPrecision);
  const bool sameArgs = (args == rhs.args);
  const bool sameValue = (value == rhs.value);
  const bool sameLength = (length == rhs.length);
  return samePrecision && sameArgs && sameValue && sameLength;
}
Expand All @@ -14,7 +15,7 @@ std::size_t std::hash<AddDescriptor>::operator()(const AddDescriptor& hash) cons
using namespace ccv::nnc::mfa::hash;
std::size_t seed = 0;
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value }));
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, 0 }));
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, (unsigned int)hash.args }));
return seed;
}

Expand All @@ -35,6 +36,7 @@ std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKe
};

AddKernelDescriptor kernelDesc;
kernelDesc.args = args;
kernelDesc.value = value;
kernelDesc.memoryPrecision = memoryPrecision;

Expand Down
5 changes: 4 additions & 1 deletion lib/nnc/mfa/v2/AddDescriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
#include "GEMMOperandPrecision.hpp"

// Cache key for a generated Add kernel: two keys that compare equal must yield
// identical shader source, so every generation input is a member and is part
// of operator==.
struct AddKernelDescriptor {
// Number of source buffers the generated kernel sums.
uint8_t args;
// Shader variant selector — presumably picks among the vectorized /
// bounds-checked / scalar bodies in AddKernel::createSource; confirm there.
uint8_t value;
// Element precision (FP16/FP32) the kernel is specialized for.
GEMMOperandPrecision memoryPrecision;
constexpr bool operator==(const AddKernelDescriptor &rhs) const { return args == rhs.args && value == rhs.value && memoryPrecision == rhs.memoryPrecision; }
};

template<>
Expand All @@ -22,6 +23,8 @@ struct std::hash<AddKernelDescriptor>
struct AddKernel;

struct AddDescriptor {
uint8_t args;

uint8_t value;

GEMMOperandPrecision memoryPrecision;
Expand Down
49 changes: 35 additions & 14 deletions lib/nnc/mfa/v2/AddKernel.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#include "AddKernel.hpp"
#include "../ccv_nnc_mfa.hpp"
#include "CodeWriter.hpp"

#include <algorithm>

AddKernel::AddKernel(AddKernelDescriptor descriptor, MTL::Device *const device) {

args = descriptor.args;

value = descriptor.value;

memoryPrecision = descriptor.memoryPrecision;
Expand All @@ -29,61 +32,79 @@ unsigned short AddKernel::createThreadgroupMemoryAllocation() const noexcept {
}

std::string AddKernel::createSource() const noexcept {
std::string shader = createConstants() + "\n";
CodeWriter source;
source += createConstants() + "\n";
std::string buffers = "";
if (value == 0 || value == 1) {
for (int i = 1; i < args; i++) {
buffers += "device real4 *src" + std::to_string(i) + " [[buffer(" + std::to_string(i) + ")]],\n";
}
} else {
for (int i = 1; i < args; i++) {
buffers += "device real *src" + std::to_string(i) + " [[buffer(" + std::to_string(i) + ")]],\n";
}
}
source.SetValue("OTHER_SOURCE_BUFFERS", buffers);
source.SetValue("DESTINATION_INDEX", std::to_string(args));
std::string items = " + src1[idx]";
for (int i = 2; i < args; i++) {
items += " + src" + std::to_string(i) + "[idx]";
}
source.SetValue("OTHER_SOURCE_ITEMS", items);
if (value == 0) {
shader += R"(
source += R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device real4 *src0 [[buffer(0)]],
device real4 *src1 [[buffer(1)]],
device real4 *destination [[buffer(2)]],
{{OTHER_SOURCE_BUFFERS}}
device real4 *destination [[buffer({{DESTINATION_INDEX}})]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
destination[idx] = src0[idx] + src1[idx];
destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}};
}
)";
} else if (value == 1) {
shader += R"(
source += R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device real4 *src0 [[buffer(0)]],
device real4 *src1 [[buffer(1)]],
device real4 *destination [[buffer(2)]],
{{OTHER_SOURCE_BUFFERS}}
device real4 *destination [[buffer({{DESTINATION_INDEX}})]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
destination[idx] = src0[idx] + src1[idx];
destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}};
}
)";
} else {
shader += R"(
source += R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device real *src0 [[buffer(0)]],
device real *src1 [[buffer(1)]],
device real *destination [[buffer(2)]],
{{OTHER_SOURCE_BUFFERS}}
device real *destination [[buffer({{DESTINATION_INDEX}})]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
destination[idx] = src0[idx] + src1[idx];
destination[idx] = src0[idx]{{OTHER_SOURCE_ITEMS}};
}
)";
}
return shader;
return source.ToString();
}

std::string AddKernel::createConstants() const noexcept {
Expand Down
Loading

0 comments on commit 7832cde

Please sign in to comment.