From 299517ac080402f77e2fc6496d25d8200b8b63ec Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 9 Jan 2025 14:02:34 -0800 Subject: [PATCH] [ET-VK][ez][buck] Simplify test buck file ## Context The targets file for the op tests define a binary and test rule for each c++ file; instead of manually defining these rules each time, create a helper function to condense the code. Differential Revision: [D67992066](https://our.internmc.facebook.com/intern/diff/D67992066/) [ghstack-poisoned] --- backends/vulkan/test/op_tests/targets.bzl | 187 ++++++---------------- 1 file changed, 47 insertions(+), 140 deletions(-) diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index ab55d5beea..d26f1a805c 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -3,6 +3,44 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps") load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def define_test_targets(test_name, extra_deps = [], src_file = None, is_fbcode = False): + deps_list = [ + "//third-party/googletest:gtest_main", + "//executorch/backends/vulkan:vulkan_graph_runtime", + runtime.external_dep_location("libtorch"), + ] + extra_deps + + src_file_str = src_file if src_file else "{}.cpp".format(test_name) + + runtime.cxx_binary( + name = "{}_bin".format(test_name), + srcs = [ + src_file_str, + ], + compiler_flags = [ + "-Wno-unused-variable", + ], + define_static_target = False, + deps = deps_list, + ) + + runtime.cxx_test( + name = test_name, + srcs = [ + src_file_str, + ], + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + fbandroid_additional_loaded_sonames = [ + "torch-code-gen", + "vulkan_graph_runtime", + "vulkan_graph_runtime_shaderlib", + ], + platforms = [ANDROID], + use_instrumentation_test = True, + deps = deps_list, + ) + + def define_common_targets(is_fbcode = False): if is_fbcode: return @@ -82,19 +120,6 @@ def define_common_targets(is_fbcode = False): default_outs = ["."], ) - runtime.cxx_binary( - name = "compute_graph_op_tests_bin", - srcs = [ - ":generated_op_correctness_tests_cpp[op_tests.cpp]", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - runtime.cxx_binary( name = "compute_graph_op_benchmarks_bin", srcs = [ @@ -111,135 +136,17 @@ def define_common_targets(is_fbcode = False): ], ) - runtime.cxx_test( - name = "compute_graph_op_tests", - srcs = [ - ":generated_op_correctness_tests_cpp[op_tests.cpp]", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], + define_test_targets( + "compute_graph_op_tests", + src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" ) - runtime.cxx_binary( - name = "sdpa_test_bin", - srcs = [ - "sdpa_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", - ], - ) - - runtime.cxx_test( - name = "sdpa_test", - srcs = [ - "sdpa_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", - "//executorch/extension/tensor:tensor", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_binary( - name = "linear_weight_int4_test_bin", - srcs = [ - "linear_weight_int4_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_test( - name = "linear_weight_int4_test", - srcs = [ - "linear_weight_int4_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", + define_test_targets( + "sdpa_test", + extra_deps = [ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_binary( - name = "rotary_embedding_test_bin", - srcs = [ - "rotary_embedding_test.cpp", - ], - compiler_flags = [ - "-Wno-unused-variable", - ], - define_static_target = False, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - runtime.external_dep_location("libtorch"), - ], - ) - - runtime.cxx_test( - name = "rotary_embedding_test", - srcs = [ - "rotary_embedding_test.cpp", - ], - contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], - fbandroid_additional_loaded_sonames = [ - "torch-code-gen", - "vulkan_graph_runtime", - "vulkan_graph_runtime_shaderlib", - ], - platforms = [ANDROID], - use_instrumentation_test = True, - deps = [ - "//third-party/googletest:gtest_main", - "//executorch/backends/vulkan:vulkan_graph_runtime", - "//executorch/extension/tensor:tensor", - runtime.external_dep_location("libtorch"), - ], + ] ) + define_test_targets("linear_weight_int4_test") + define_test_targets("rotary_embedding_test")