From da6bb12f86bb8de694859f92d22347dd70cb7ab5 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 11 Dec 2024 18:35:58 +0100 Subject: [PATCH 01/39] Harmonize scripts to compute versions (#674) Harmonizes the scripts to compute the versions. JSON files are only written if `--write-json` is passed. Furthermore, `--version-suffix` can no longer be combined with other release types and gives full control over defining a suffix to the user. Similar changes are applied to the scripts used in IREE. --- .github/workflows/build_packages.yml | 6 +- .../python_deploy/compute_common_version.py | 68 ++++++++----------- .../python_deploy/compute_local_version.py | 55 +++++++++------ 3 files changed, 66 insertions(+), 63 deletions(-) diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 9530cf4d6..e711dc7e2 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -67,9 +67,9 @@ jobs: id: version_local run: | echo "version_suffix=${version_suffix}" >> $GITHUB_OUTPUT - python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} sharktank - python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} shortfin - python3 build_tools/python_deploy/compute_common_version.py -rc --version-suffix=${version_suffix} --write-json + python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} --write-json sharktank + python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} --write-json shortfin + python3 build_tools/python_deploy/compute_common_version.py --version-suffix=${version_suffix} --write-json - name: Upload version_local.json files uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py index aa193bcc1..b9e7c438c 100755 --- a/build_tools/python_deploy/compute_common_version.py +++ b/build_tools/python_deploy/compute_common_version.py @@ -17,37 +17,28 @@ from pathlib import Path import json from datetime import datetime -import sys +import subprocess from packaging.version import Version parser = argparse.ArgumentParser() parser.add_argument("--write-json", action="store_true") -parser.add_argument("--version-suffix", action="store", type=str) -release_type = parser.add_mutually_exclusive_group() -release_type.add_argument("-stable", "--stable-release", action="store_true") # default +release_type = parser.add_mutually_exclusive_group(required=True) +release_type.add_argument("-stable", "--stable-release", action="store_true") release_type.add_argument("-rc", "--nightly-release", action="store_true") - +release_type.add_argument("-dev", "--development-release", action="store_true") +release_type.add_argument("--version-suffix", action="store", type=str) args = parser.parse_args() -if not (args.stable_release or args.nightly_release): - parser.print_usage(sys.stderr) - sys.stderr.write("error: A release type is required\n") - sys.exit(1) - -if args.stable_release and args.version_suffix: - sys.stderr.write("error: A version suffix is only supported for stable releases\n") - sys.exit(1) - THIS_DIR = Path(__file__).parent.resolve() REPO_ROOT = THIS_DIR.parent.parent -VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version.json" -VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version.json" -VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" 
+VERSION_FILE_SHARKTANK_PATH = REPO_ROOT / "sharktank/version.json" +VERSION_FILE_SHORTFIN_PATH = REPO_ROOT / "shortfin/version.json" +VERSION_FILE_LOCAL_PATH = REPO_ROOT / "shark-ai/version_local.json" def load_version_info(version_file): @@ -55,35 +46,36 @@ def load_version_info(version_file): return json.load(f) -def write_version_info(): - with open(VERSION_FILE_LOCAL, "w") as f: - json.dump(version_local, f, indent=2) +def write_version_info(version_file, version): + with open(version_file, "w") as f: + json.dump({"package-version": version}, f, indent=2) f.write("\n") -sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) -SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") -SHARKTANK_BASE_VERSION = Version(SHARKTANK_PACKAGE_VERSION).base_version +sharktank_version = load_version_info(VERSION_FILE_SHARKTANK_PATH) +sharktank_package_version = sharktank_version.get("package-version") +sharktank_base_version = Version(sharktank_package_version).base_version -shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) -SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") -SHORTFIN_BASE_VERSION = Version(SHORTFIN_PACKAGE_VERSION).base_version +shortfin_version = load_version_info(VERSION_FILE_SHORTFIN_PATH) +shortfin_package_version = shortfin_version.get("package-version") +shortfin_base_version = Version(shortfin_package_version).base_version -if SHARKTANK_BASE_VERSION > SHORTFIN_BASE_VERSION: - COMMON_VERSION = SHARKTANK_BASE_VERSION +if sharktank_base_version > shortfin_base_version: + common_version = sharktank_base_version else: - COMMON_VERSION = SHORTFIN_BASE_VERSION + common_version = shortfin_base_version if args.nightly_release: - if args.version_suffix: - VERSION_SUFFIX = args.version_suffix - else: - VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d") - - COMMON_VERSION += VERSION_SUFFIX + common_version += "rc" + datetime.today().strftime("%Y%m%d") +elif args.development_release: + common_version += ( + ".dev0+" + + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + ) +elif args.version_suffix: + common_version += args.version_suffix if args.write_json: - version_local = {"package-version": COMMON_VERSION} - write_version_info() + write_version_info(VERSION_FILE_LOCAL_PATH, common_version) -print(COMMON_VERSION) +print(common_version) diff --git a/build_tools/python_deploy/compute_local_version.py b/build_tools/python_deploy/compute_local_version.py index 0465fa443..3cb0fab34 100755 --- a/build_tools/python_deploy/compute_local_version.py +++ b/build_tools/python_deploy/compute_local_version.py @@ -6,50 +6,61 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # This scripts grabs the X.Y.Z[.dev]` version identifier from a -# `version.json` and writes the corresponding +# `version.json` and writes a version identifier for a stable, +# nightly or development release, or a release with an arbitrary # `X.Y.ZrcYYYYMMDD` version identifier to `version_local.json`. 
import argparse from pathlib import Path import json from datetime import datetime +import subprocess from packaging.version import Version parser = argparse.ArgumentParser() parser.add_argument("path", type=Path) -parser.add_argument("--version-suffix", action="store", type=str) +parser.add_argument("--write-json", action="store_true") + +release_type = parser.add_mutually_exclusive_group(required=True) +release_type.add_argument("-stable", "--stable-release", action="store_true") +release_type.add_argument("-rc", "--nightly-release", action="store_true") +release_type.add_argument("-dev", "--development-release", action="store_true") +release_type.add_argument("--version-suffix", action="store", type=str) + args = parser.parse_args() -VERSION_FILE = args.path / "version.json" -VERSION_FILE_LOCAL = args.path / "version_local.json" +VERSION_FILE_PATH = args.path / "version.json" +VERSION_FILE_LOCAL_PATH = args.path / "version_local.json" -def load_version_info(): - with open(VERSION_FILE, "rt") as f: +def load_version_info(version_file): + with open(version_file, "rt") as f: return json.load(f) -def write_version_info(): - with open(VERSION_FILE_LOCAL, "w") as f: - json.dump(version_local, f, indent=2) +def write_version_info(version_file, version): + with open(version_file, "w") as f: + json.dump({"package-version": version}, f, indent=2) f.write("\n") -version_info = load_version_info() - -if args.version_suffix: - VERSION_SUFFIX = args.version_suffix -else: - VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d") - -PACKAGE_VERSION = version_info.get("package-version") -PACKAGE_BASE_VERSION = Version(PACKAGE_VERSION).base_version -PACKAGE_LOCAL_VERSION = PACKAGE_BASE_VERSION + VERSION_SUFFIX +version_info = load_version_info(VERSION_FILE_PATH) +package_version = version_info.get("package-version") +current_version = Version(package_version).base_version -version_local = {"package-version": PACKAGE_LOCAL_VERSION} +if args.nightly_release: + current_version += "rc" + datetime.today().strftime("%Y%m%d") +elif args.development_release: + current_version += ( + ".dev0+" + + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + ) +elif args.version_suffix: + current_version += args.version_suffix -write_version_info() +if args.write_json: + write_version_info(VERSION_FILE_LOCAL_PATH, current_version) -print(PACKAGE_LOCAL_VERSION) +print(current_version) From 1e26b205287d9a7153639ca0f66e0fcd51c5138e Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 11 Dec 2024 14:55:15 -0500 Subject: [PATCH 02/39] [tuner]: use translation_info binding (#669) This PR is relevant to the task in https://github.com/nod-ai/shark-ai/issues/453 : use IREE bindings for compilation info (incl., lowering_config and translation_info). Use translation_info from IREE python binding. 
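
As a quick illustration, this is the construction pattern the updated tests now use (a minimal sketch, not part of the diff: it assumes an active MLIR context with the IREE dialects loaded, e.g. via the tuner's TunerContext, plus a `lowering_config` already built with `common.get_lowering_config`):

    from iree.compiler.dialects import iree_codegen  # type: ignore
    from iree.compiler.dialects import iree_gpu  # type: ignore
    from . import common  # tuner helpers: get_translation_info_config, Configuration

    # Select the pass pipeline and pipeline options through the bindings
    # instead of string-formatting the translation_info attribute by hand.
    pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
        iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
    )
    pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True)
    # Config dictionary carrying gpu_pipeline_options and "amdgpu-waves-per-eu".
    config_dict = common.get_translation_info_config(pipeline_options, 8)
    translation_info = iree_codegen.TranslationInfoAttr.get(
        pipeline_attr, None, [16, 16, 1], 16, config_dict
    )
    config = common.Configuration(
        translation_info=translation_info,
        lowering_config=lowering_config,
    )
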
--------- Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 106 +++++----------------------- tuner/tuner/candidate_gen_test.py | 102 +++++++++++++++++--------- tuner/tuner/common.py | 51 +++++++++---- tuner/tuner/common_test.py | 53 +++++++++----- tuner/tuner/dispatch_constraints.py | 18 +++-- tuner/tuner/dispatch_parser_test.py | 46 ++++++++---- 6 files changed, 203 insertions(+), 173 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index bc01bb709..a3252130e 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -51,6 +51,12 @@ def apply_configuration( ) = lowering_config.subgroup_count_mn workgroup_sizes = lowering_config.workgroup_tile_sizes reduction_sizes = lowering_config.reduction_tile_sizes + gpu_pipeline_options = configuration.translation_info.configuration[ + GPU_PIPELINE_OPTIONS_KEY + ] + waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ + WAVES_PER_EU_KEY + ] tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -63,11 +69,11 @@ def apply_configuration( expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},' repl2 = f"workgroup = {workgroup_sizes}" repl3 = f"reduction = {reduction_sizes}" - repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" - repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}" + repl5 = f'"amdgpu-waves-per-eu" = {waves_per_eu}' new_mlir = "" for line in template: @@ -128,15 +134,6 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op @@ -145,13 +142,8 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param }} @@ -197,16 +189,6 @@ def get_transform_function_conv( filter = f"tensor<{problem_size.rhs_type}>" output = 
f"tensor<{dynamic_batch_output_ty}>" - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -217,13 +199,8 @@ def get_transform_function_conv( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param }} @@ -262,16 +239,6 @@ def get_transform_function_broadcast_rhs_mmt( functionName: str, configuration: Configuration, ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - lhs_dynamic_batch = problem_size.lhs_type lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() lhs_dynamic_batch.shape[0] = -1 @@ -284,13 +251,8 @@ def get_transform_function_broadcast_rhs_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} @@ -351,16 +313,6 @@ def get_transform_function_batch_mmt( functionName: str, configuration: Configuration, ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op @@ -369,13 +321,8 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %generic, %config : 
!transform.any_op, !transform.any_param }} @@ -421,16 +368,6 @@ def get_transform_function_batch_matmul( input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -441,13 +378,8 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param }} diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 45da323c5..7f104bcd9 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from . import candidate_gen from . import common @@ -56,14 +57,17 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=16, subgroup_n_count=16, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, 8) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 16, config_dict + ) config = common.Configuration( - subgroup_size=16, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( - prefetch_shared_memory=True - ), - waves_per_eu=8, ) problem_size = common.ProblemSize( @@ -118,16 +122,21 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get( + reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( + iree_gpu.ReorderWorkgroupsStrategy.Transpose + ) + ) + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( - reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( - iree_gpu.ReorderWorkgroupsStrategy.Transpose - ) - ), - waves_per_eu=2, ) problem_size = common.ProblemSize( @@ -191,12 +200,17 @@ def 
test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( @@ -246,12 +260,17 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( @@ -304,12 +323,17 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( @@ -360,12 +384,17 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=4, ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( @@ -440,12 +469,17 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + 
pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=4, ) tf_mlir = candidate_gen.ContractionTuner( diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 0a2b03fd1..c683c5bdc 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore class CommonTypes: @@ -112,11 +113,16 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: @dataclass class Configuration: - subgroup_size: int - workgroup_size: list[int] + translation_info: iree_codegen.TranslationInfoAttr lowering_config: iree_gpu.LoweringConfigAttr - gpu_pipeline_options: iree_gpu.PipelineOptionsAttr - waves_per_eu: int + + +# The key name for GPUPipelineOptionsAttr in the translation info config dictionary. +GPU_PIPELINE_OPTIONS_KEY = "gpu_pipeline_options" +# The key name for llvm_func_attrs attribute in the translation info config dictionary. +LLVM_FUNC_ATTRS_KEY = "llvm_func_attrs" +# The Key name for the 'amdgpu-waves-per-eu' within the llvm_func_attrs attribute. +WAVES_PER_EU_KEY = "amdgpu-waves-per-eu" def get_lowering_config( @@ -157,15 +163,34 @@ def get_lowering_config( return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = "" - pipeline_options = configuration.gpu_pipeline_options - if pipeline_options != iree_gpu.PipelineOptionsAttr.get(): - extra_config += f", gpu_pipeline_options = {pipeline_options}" - - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config +# Generate a config dictionary used in translation_info attribute. +def get_translation_info_config( + pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int +) -> ir.DictAttr: + """ + Example IR + translation_info = #iree_codegen.translation_info< + pipeline = LLVMGPUVectorDistribute workgroup_size = [512, 1, 1] subgroup_size = 64, + {gpu_pipeline_options = #iree_gpu.pipeline_options<...>, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "3"} + } + > + """ + waves_per_eu_str = str(waves_per_eu) + + # Create the waves_per_eu dictionary attribute. 
+ waves_per_eu_dict = ir.DictAttr.get( + {WAVES_PER_EU_KEY: ir.StringAttr.get(waves_per_eu_str)} + ) + + config_dict = ir.DictAttr.get( + { + GPU_PIPELINE_OPTIONS_KEY: pipeline_options, + LLVM_FUNC_ATTRS_KEY: waves_per_eu_dict, + } + ) + + return config_dict def read_input_mlir(filename: str) -> list[str]: diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 6d76c216f..af1e1bf9a 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -15,6 +15,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore @pytest.fixture @@ -84,27 +85,36 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) config = common.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) - config1_str: str = common.get_pipeline_config(config) - assert config1_str == "" - - config.waves_per_eu = 4 - config2_str: str = common.get_pipeline_config(config) - assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + config1_str: str = str( + config.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] + ) + assert config1_str == '{"amdgpu-waves-per-eu" = "2"}' - config.gpu_pipeline_options = iree_gpu.PipelineOptionsAttr.get( - prefetch_shared_memory=True + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config3_str = common.get_pipeline_config(config) + config = common.Configuration( + translation_info=translation_info, + lowering_config=lowering_config, + ) + config2_str: str = str(config.translation_info.configuration) assert ( - config3_str - == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + config2_str + == '{gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}' ) @@ -207,12 +217,17 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: == "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>" ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) config = common.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) assert config.lowering_config.mma_kind is None 
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index f86523389..8ba202310 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -244,13 +244,21 @@ def generate_solutions( subgroup_m_count=lookup(sg_m_cnt), subgroup_n_count=lookup(sg_n_cnt), ) - config = Configuration( - lookup(subgroup_size), + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = get_translation_info_config( + pipeline_options, lookup(waves_per_eu) + ) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, + None, [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - lowering_config, - iree_gpu.PipelineOptionsAttr.get(), - lookup(waves_per_eu), + lookup(subgroup_size), + config_dict, ) + config = Configuration(translation_info, lowering_config) solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) i += 1 yield config diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index db8c4a7da..a63576808 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -15,6 +15,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import func # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from . import common from . import dispatch_parser @@ -50,12 +51,17 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) - config = dispatch_parser.Configuration( - subgroup_size=0, - workgroup_size=[], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 0) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [], 0, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=0, ) lowering_config = config.lowering_config assert lowering_config.workgroup_tile_sizes == [128, 320, 0] @@ -73,12 +79,17 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) - config = dispatch_parser.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 1) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=1, ) assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0] assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16] @@ -95,12 +106,17 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) - config = 
dispatch_parser.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] From 7e62c25a3f307a9a0e3191fc9ebdd875aacab1fc Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:12:42 -0600 Subject: [PATCH 03/39] loop over cache_partitions to enable fusion (#677) Co-authored-by: Rob Suderman --- sharktank/sharktank/layers/kv_cache.py | 52 ++++++++------------------ 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index c73b7a8f4..46e94ff90 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -447,30 +447,25 @@ def write_timestep( bs, *_ = seq_positions.shape assert len(cache_partitions) == self.cache_partition_count - partition_count = len(cache_partitions) + # [bs, 1, atten_head_count, attn_head_dim] + for idx, cache_partition in enumerate(cache_partitions): + # [bs, 1] + page_index = seq_positions // self.block_seq_stride - # [bs, partitions, atten_head_count, attn_head_dim] - cache_partitions = ops.cat(cache_partitions, dim=1) + page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1)) + page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1) - # [bs, 1] - page_index = seq_positions // self.block_seq_stride + # [1, 1] + partitions = torch.tensor(idx).unsqueeze(0) - page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1)) - page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1) - - # [1, partitions] - partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0) - - # [bs, partitions] - page_id = page_id.repeat(1, partition_count) - transformer_block = torch.full( - (bs, partition_count), transformer_block_index, device=device - ) - page_offset = page_offset.repeat(1, partition_count) - partitions = partitions.repeat(bs, 1) + # [bs, 1] + transformer_block = torch.full( + (bs, 1), transformer_block_index, device=device + ) + partitions = partitions.repeat(bs, 1) - indices = (page_id, transformer_block, partitions, page_offset) - page_table.index_put_(indices=indices, values=cache_partitions) + indices = (page_id, transformer_block, partitions, page_offset) + page_table.index_put_(indices=indices, values=cache_partition) return @@ -490,14 +485,6 @@ def write( page_table = self.unflatten_page_table(state) # 6D bs, block_seq_len, *_ = page_ids.shape - # Blocks dim 1,2 according to the configured block stride. - blocked_shape = [ - bs, - block_seq_len, - self.block_seq_stride, - self.attn_head_count, - self.attn_head_dim, - ] # Reshape the page cache into sub-blocks so that we can index at the # granularity of the transformer_block and cache partition. 
@@ -513,21 +500,14 @@ def write( transformer_block_index * transformer_block_stride ) - part_block_views = [] - subblock_ids_kv = [] for index, partition in enumerate(cache_partitions): part_block_view = partition.unflatten( 1, (block_seq_len, self.block_seq_stride) ) part_block_view = part_block_view.flatten(0, 1) - part_block_views.append(part_block_view) subblock_ids = ( (base_subblock_ids + index) if index > 0 else base_subblock_ids ).flatten(0, 1) - subblock_ids_kv.append(subblock_ids) - - subblock_ids = ops.cat(subblock_ids_kv) - part_block_view = ops.cat(part_block_views, dim=0) - subblock_table.index_copy_(0, subblock_ids, part_block_view) + subblock_table.index_copy_(0, subblock_ids, part_block_view) From ffb870f5a1e17fa95f56ded016c3141947d31b9d Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 11 Dec 2024 23:08:47 -0500 Subject: [PATCH 04/39] Shortfin LLM Debug Ergonomics: one flag to dump them all (#668) Thanks @stbaione for originally writing a lot of the debug code in here. I organized his debug logging & dumping into a separate file & made the output easier to deal with. --- .../shortfin_apps/llm/components/generate.py | 6 +- .../shortfin_apps/llm/components/messages.py | 3 +- .../shortfin_apps/llm/components/service.py | 17 +- .../llm/components/service_debug_dumper.py | 216 ++++++++++++++++++ 4 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py diff --git a/shortfin/python/shortfin_apps/llm/components/generate.py b/shortfin/python/shortfin_apps/llm/components/generate.py index 698f779fb..9e9fea692 100644 --- a/shortfin/python/shortfin_apps/llm/components/generate.py +++ b/shortfin/python/shortfin_apps/llm/components/generate.py @@ -49,7 +49,11 @@ def __init__( self.eos_token_id = eos_token_id async def run(self): - exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids) + exec = InferenceExecRequest( + phase=InferencePhase.PREFILL, + input_token_ids=self.input_token_ids, + rid=self.gen_req.rid, + ) try: self.client.batcher.submit(exec) await exec.done diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index 9e2ab7179..724f71569 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -21,12 +21,13 @@ class InferencePhase(Enum): class InferenceExecRequest(sf.Message): """Performs a prefill operation.""" - def __init__(self, phase: InferencePhase, input_token_ids: list[int]): + def __init__(self, phase: InferencePhase, input_token_ids: list[int], rid=None): super().__init__() self.phase = phase self.start_position: int = 0 self.input_token_ids = input_token_ids self.done = sf.VoidFuture() + self.rid = rid # Response control. # If True, return all sequence position logits. 
If False, return only diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 71d54c234..9ef8b5c4d 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -29,6 +29,9 @@ isolation.name.lower(): isolation for isolation in sf.ProgramIsolation } +import os +from .service_debug_dumper import SERVICE_DEBUG_DUMPER + class GenerateService: """Top level service interface for generating text against a model.""" @@ -438,7 +441,19 @@ async def run(self): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(args)]), ) - # Invoke. Logits are of shape [bs, bsl, d]. + + # pre-invocation args dump + if os.getenv("SHORTFIN_DEBUG_LLM_SERVICE", "False").lower() in ( + "true", + "yes", + "1", + "y", + ): + await SERVICE_DEBUG_DUMPER.pre_invocation_debug_dump( + executor=self, local_vars=locals() + ) + + # Invoke VMFB. Logits are of shape [bs, bsl, d]. (logits,) = await fn(*args, fiber=self.fiber) # publish cache pages diff --git a/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py new file mode 100644 index 000000000..ae492eab3 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py @@ -0,0 +1,216 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +from datetime import datetime +from pathlib import Path +import json +import numpy as np +import pandas as pd +from typing import Dict, Any +from pprint import pformat + +logger = logging.getLogger(__name__) + + +class ServiceDebugDumper: + def __init__(self): + """Initialize debug service with a new dump directory for this session.""" + self.dump_id = 0 + self.boot_timestamp = datetime.now().isoformat() + self.debug_data_dir = Path.home() / ".shortfin/debug/" + self.dump_dir = ( + self.debug_data_dir / "llm_service_invocation_dumps" / self.boot_timestamp + ) + self.dump_dir.mkdir(parents=True, exist_ok=False) + logger.info( + f"[debug_service.py] Please find debug dumps for service.py in {self.dump_dir}" + ) + + async def pre_invocation_debug_dump( + self, executor: "InferenceExecutorProcess", local_vars: Dict[str, Any] + ): + """Comprehensive debug dump before inference invocation.""" + # Extract variables from locals + is_decode = local_vars["is_decode"] + device0 = local_vars["device0"] + fn = local_vars["fn"] + req_bs = local_vars["req_bs"] + bsl = local_vars["bsl"] + seq_stride = local_vars["seq_stride"] + block_count = local_vars["block_count"] + req_count = local_vars["req_count"] + tokens = local_vars["tokens"] + start_positions = local_vars.get("start_positions") + seq_lens = local_vars["seq_lens"] + seq_block_ids = local_vars["seq_block_ids"] + args = local_vars["args"] + + phase = executor.phase + exec_requests = executor.exec_requests + model_params = executor.service.model_params + + dump_path = self.dump_dir / f"{self.dump_id}" + dump_path.mkdir(parents=True, exist_ok=True) + + # Prepare debug info dictionary + debug_info = { + "metadata": { + "dump_id": self.dump_id, + "dump_timestamp": datetime.now().isoformat(), + "phase": str(phase), + "is_decode": is_decode, + "device": str(device0), + "function": str(fn), + }, + "batch_info": { + 
"request_batch_size": req_bs, + "block_sequence_length": int(bsl), + "sequence_stride": seq_stride, + "block_count": block_count, + "actual_request_count": req_count, + }, + "requests": [ + { + "index": i, + "start_position": req.start_position, + "rid": req.rid, + "input_token_ids": req.input_token_ids.tolist() + if hasattr(req.input_token_ids, "tolist") + else list(req.input_token_ids), + "input_length": len(req.input_token_ids), + "cache_pages": req.cache_page_indices(block_count), + } + for i, req in enumerate(exec_requests) + ], + "tensor_shapes": { + "tokens": tokens.shape, + **({"start_positions": start_positions.shape} if is_decode else {}), + "seq_lens": seq_lens.shape, + "seq_block_ids": seq_block_ids.shape, + }, + "tensor_values": { + "tokens": tokens.for_transfer().items.tolist() + if hasattr(tokens.for_transfer().items, "tolist") + else list(tokens.for_transfer().items), + **( + { + "start_positions": start_positions.for_transfer().items.tolist() + if hasattr(start_positions.for_transfer().items, "tolist") + else list(start_positions.for_transfer().items) + } + if is_decode + else {} + ), + "sequence_lengths": seq_lens.for_transfer().items.tolist() + if hasattr(seq_lens.for_transfer().items, "tolist") + else list(seq_lens.for_transfer().items), + "sequence_block_ids": seq_block_ids.for_transfer().items.tolist() + if hasattr(seq_block_ids.for_transfer().items, "tolist") + else list(seq_block_ids.for_transfer().items), + }, + "model_config": { + "prefill_batch_sizes": model_params.prefill_batch_sizes, + "decode_batch_sizes": model_params.decode_batch_sizes, + "attn_dtype": str(model_params.attn_dtype), + "paged_kv_cache": { + "device_block_count": model_params.paged_kv_cache.device_block_count, + "block_seq_stride": model_params.paged_kv_cache.block_seq_stride, + "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm, + }, + }, + } + + # Save debug info as JSON + with open(dump_path / "info.json", "w") as f: + json.dump(debug_info, f, indent=2) + + # Save program arguments + path = dump_path + args_np = [] + for i, a in enumerate(args): + host_array = a.for_transfer() + host_array.copy_from(a) + await a.device + args_np.append(np.array(host_array)) + + # Save binary numpy arrays + for i, arr in enumerate(args_np): + np.save(path / f"{i}.npy", arr) + + # Generate human-readable report + with open(path / "saved_program_args.txt", "w") as f: + for i, arr in enumerate(args_np): + f.write(f"\n{'='*80}\n") + f.write(f"{i}.npy:\n") + f.write(f"{'='*80}\n\n") + + # Basic info + f.write(f"Shape: {arr.shape}\n") + f.write(f"Dtype: {arr.dtype}\n") + f.write(f"Total elements: {arr.size}\n") + f.write(f"Dimensions: {arr.ndim}\n\n") + + # Stats + f.write("Statistics:\n") + nan_count = np.count_nonzero(np.isnan(arr)) + inf_count = np.count_nonzero(np.isinf(arr)) + f.write(f"- NaN count: {nan_count}\n") + f.write(f"- Inf count: {inf_count}\n") + + if nan_count == 0 and inf_count == 0: + f.write(f"- Min: {np.min(arr)}\n") + f.write(f"- Max: {np.max(arr)}\n") + f.write(f"- Mean: {np.mean(arr):.6f}\n") + f.write(f"- Median: {np.median(arr):.6f}\n") + f.write(f"- Range: {np.ptp(arr)}\n") + try: + mode = pd.Series(arr.flatten()).mode().iloc[0] + f.write(f"- Mode: {mode}\n") + except: + f.write("- Mode: Unable to compute\n") + + if np.issubdtype(arr.dtype, np.number): + try: + hist, bins = np.histogram(arr.flatten(), bins="auto") + f.write("\nHistogram:\n") + f.write( + "Bins: " + + pformat(bins.tolist(), width=80, compact=True) + + "\n" + ) + f.write( + "Counts: " + + 
pformat(hist.tolist(), width=80, compact=True) + + "\n" + ) + except Exception as e: + f.write(f"\nHistogram computation failed: {str(e)}\n") + else: + f.write("Skipping additional statistics due to NaN/Inf values\n") + + f.write("\nArray contents:\n") + if arr.size <= 64: + formatted = pformat(arr.tolist(), width=80, compact=True) + f.write(formatted + "\n") + else: + f.write("\nFirst 5 elements:\n") + f.write( + pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + + "\n" + ) + f.write("\nLast 5 elements:\n") + f.write( + pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + + "\n" + ) + + self.dump_id += 1 + + +# Create single instance +SERVICE_DEBUG_DUMPER = ServiceDebugDumper() From d279afff48c56ab5e0e6e5ebe2717ccf0b26ee50 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 11 Dec 2024 23:51:09 -0500 Subject: [PATCH 05/39] [tuner]: use compilation_info binding (#678) This PR is relevant to the task in https://github.com/nod-ai/shark-ai/issues/453 : use IREE bindings for compilation info (incl., lowering_config and translation_info). Retire data class `configuration` and use the `compilation_info` from IREE python binding. Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 105 +++++++++++++--------------- tuner/tuner/candidate_gen_test.py | 51 ++++++-------- tuner/tuner/common.py | 6 -- tuner/tuner/common_test.py | 23 +++--- tuner/tuner/dispatch_constraints.py | 8 ++- tuner/tuner/dispatch_parser.py | 4 +- tuner/tuner/dispatch_parser_test.py | 87 +++++++++++++++++------ tuner/tuner/libtuner.py | 2 +- 8 files changed, 157 insertions(+), 129 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index a3252130e..ed150bfec 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -41,9 +41,9 @@ def apply_configuration( template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: - lowering_config = configuration.lowering_config + lowering_config = compilation_info.lowering_config intrinsic = lowering_config.mma_kind ( subgroup_m_count, @@ -51,13 +51,13 @@ def apply_configuration( ) = lowering_config.subgroup_count_mn workgroup_sizes = lowering_config.workgroup_tile_sizes reduction_sizes = lowering_config.reduction_tile_sizes - gpu_pipeline_options = configuration.translation_info.configuration[ + gpu_pipeline_options = compilation_info.translation_info.configuration[ GPU_PIPELINE_OPTIONS_KEY ] - waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ + waves_per_eu = compilation_info.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ WAVES_PER_EU_KEY ] - tune_logger.info(f"Applying: {configuration}") + tune_logger.info(f"Applying: {compilation_info}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" ) @@ -69,7 +69,7 @@ def apply_configuration( expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},' + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, compilation_info.translation_info.workgroup_size))}] subgroup_size = {compilation_info.translation_info.subgroup_size},' repl2 = f"workgroup = {workgroup_sizes}" repl3 = f"reduction = {reduction_sizes}" repl4 = f"gpu_pipeline_options = 
{gpu_pipeline_options}" @@ -101,7 +101,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: """Apply parameter transformations to the operation.""" pass @@ -132,7 +132,10 @@ def find_handler(self, op_name: str) -> DispatchTuner: class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -141,10 +144,7 @@ def get_transform_function_mmt( %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -153,21 +153,23 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( self.get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration + problem_size, f"match_mmt_{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( - self.get_transform_function_mmt(problem_size, f"match_op", configuration), + self.get_transform_function_mmt( + problem_size, f"match_op", compilation_info + ), " ", ) return MLIRTransformation(template, modified, embeddable) @@ -175,7 +177,10 @@ def apply_params( class ConvTuner(DispatchTuner, ConvParser): def get_transform_function_conv( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration + self, + problem_size: ProblemSize, + functionName: str, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: dynamic_batch_input_ty = problem_size.lhs_type dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() @@ -198,10 +203,7 @@ def get_transform_function_conv( ins(%lhs, %rhs : {input}, {filter}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param }} """ @@ -210,23 +212,25 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: conv_dims = ConvDimInfo.from_problem_size(problem_size) modified = indent( 
self.get_transform_function_conv( problem_size, f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, + compilation_info, ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( - self.get_transform_function_conv(problem_size, f"match_op", configuration), + self.get_transform_function_conv( + problem_size, f"match_op", compilation_info + ), " ", ) return MLIRTransformation(template, modified, embeddable) @@ -237,7 +241,7 @@ def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: lhs_dynamic_batch = problem_size.lhs_type lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() @@ -250,10 +254,7 @@ def get_transform_function_broadcast_rhs_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param +%config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -262,23 +263,23 @@ def apply_params_broadcast_rhs_mmt( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration + problem_size, f"match_op", compilation_info ), " ", ) @@ -288,11 +289,11 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: if self.is_broadcast_rhs_mmt(template): return self.apply_params_broadcast_rhs_mmt( - problem_size, template, configuration + problem_size, template, compilation_info ) # TODO: Generate transform function. 
@@ -300,7 +301,7 @@ def apply_params( template, apply_configuration( template, - configuration, + compilation_info, ), "", ) @@ -311,7 +312,7 @@ def get_transform_function_batch_mmt( self, problem_size: ProblemSize, functionName: str, - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: return f""" transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -320,10 +321,7 @@ def get_transform_function_batch_mmt( %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param +%config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} """ @@ -332,24 +330,24 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK B = problem_size.matmul_size.B modified = indent( self.get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", compilation_info ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_batch_mmt( - problem_size, f"match_op", configuration + problem_size, f"match_op", compilation_info ), " ", ) @@ -362,7 +360,7 @@ def get_transform_function_batch_matmul( problem_size: ProblemSize, tile_dims: str, functionName: str, - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> str: input0 = f"tensor<{problem_size.lhs_type}>" input1 = f"tensor<{problem_size.rhs_type}>" @@ -377,10 +375,7 @@ def get_transform_function_batch_matmul( ins(%lhs, %rhs : {input0}, {input1}) outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}, - translation_info = {configuration.translation_info} - > -> !transform.any_param + %config = transform.param.constant {compilation_info} -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param }} """ @@ -389,7 +384,7 @@ def apply_params( self, problem_size: ProblemSize, template: list[str], - configuration: Configuration, + compilation_info: iree_codegen.CompilationInfoAttr, ) -> MLIRTransformation: M, N, K = problem_size.MNK modified = indent( @@ -397,18 +392,18 @@ def apply_params( problem_size, self.tile_dims, f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, + compilation_info, ), "// ", ) modified += apply_configuration( template, - configuration, + compilation_info, ) embeddable = indent( self.get_transform_function_batch_matmul( - problem_size, self.tile_dims, f"match_op", configuration + problem_size, self.tile_dims, f"match_op", compilation_info ), " ", ) diff --git a/tuner/tuner/candidate_gen_test.py 
b/tuner/tuner/candidate_gen_test.py index 7f104bcd9..0428ab7d2 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -65,9 +65,8 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 16, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) problem_size = common.ProblemSize( @@ -77,7 +76,9 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: common.ShapedType([M, N], tuner_ctx.type.f32), common.DispatchKind.mmt, ) - tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + tf_mlir = candidate_gen.MmtTuner().apply_params( + problem_size, mlir_template, compilation_info + ) modified = tf_mlir.modified embeddable = tf_mlir.embeddable @@ -134,9 +135,8 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) problem_size = common.ProblemSize( @@ -147,7 +147,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -208,13 +208,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) new_mlir = tf_mlir.modified @@ -268,13 +267,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -331,13 +329,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -392,13 +389,12 @@ def 
test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config + problem_size, mlir_template, compilation_info ) modified = tf_mlir.modified @@ -477,14 +473,13 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [128, 2, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) tf_mlir = candidate_gen.ContractionTuner( "mk", "nk", "mnk" - ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, compilation_info) modified = tf_mlir.modified embeddable = tf_mlir.embeddable diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index c683c5bdc..5c79bd8dd 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -111,12 +111,6 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: return list(filter(is_comptible, mma_intrinsics)) -@dataclass -class Configuration: - translation_info: iree_codegen.TranslationInfoAttr - lowering_config: iree_gpu.LoweringConfigAttr - - # The key name for GPUPipelineOptionsAttr in the translation info config dictionary. GPU_PIPELINE_OPTIONS_KEY = "gpu_pipeline_options" # The key name for llvm_func_attrs attribute in the translation info config dictionary. 
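For readers following this refactor: the ad-hoc `Configuration` dataclass removed above is replaced throughout the tuner by `iree_codegen.CompilationInfoAttr`, which bundles the same two attributes. A minimal sketch of the new call pattern is below; the import path and the construction of `lowering_config`/`translation_info` are assumptions taken from the surrounding tuner sources and tests, not part of this patch.

```python
# Sketch of the migration in this patch. `lowering_config` and `translation_info`
# are assumed to be built the same way as in the tests shown above.
from iree.compiler.dialects import iree_codegen  # import path as used by the tuner (assumed)

# Old (removed above):
#   config = Configuration(translation_info=translation_info,
#                          lowering_config=lowering_config)
# New:
compilation_info = iree_codegen.CompilationInfoAttr.get(
    lowering_config, translation_info
)

# The wrapped attributes remain accessible under the same names, so downstream
# helpers such as get_contract_workgroup_sizes() keep working unchanged:
workgroup_sizes = compilation_info.lowering_config.workgroup_tile_sizes
translation_config = compilation_info.translation_info.configuration
```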
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index af1e1bf9a..6157bb355 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -93,12 +93,11 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) config1_str: str = str( - config.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] + compilation_info.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] ) assert config1_str == '{"amdgpu-waves-per-eu" = "2"}' @@ -107,11 +106,10 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - config2_str: str = str(config.translation_info.configuration) + config2_str: str = str(compilation_info.translation_info.configuration) assert ( config2_str == '{gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}' @@ -225,10 +223,9 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - assert config.lowering_config.mma_kind is None - assert config.lowering_config.subgroup_count_mn == (1, 1) + assert compilation_info.lowering_config.mma_kind is None + assert compilation_info.lowering_config.subgroup_count_mn == (1, 1) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 8ba202310..797c83534 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -178,7 +178,7 @@ def generate_solutions( problem_size: ProblemSize, num_subgrups: int, mma_intrinsics: list[iree_gpu.MMAIntrinsic], -) -> Iterator[Configuration]: +) -> Iterator[iree_codegen.CompilationInfoAttr]: M, N, K = problem_size.MNK tuner_ctx.logger.info(f"{M},{N},{K}") m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") @@ -258,7 +258,9 @@ def generate_solutions( lookup(subgroup_size), config_dict, ) - config = Configuration(translation_info, lowering_config) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) i += 1 - yield config + yield compilation_info diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index cc63c89a3..fe95c52a6 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,7 +21,7 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_contract_workgroup_sizes( - configuration: Configuration, tile_dims: str + configuration: iree_codegen.CompilationInfoAttr, tile_dims: str ) -> list[int]: m, n, _k = configuration.lowering_config.workgroup_tile_sizes @@ -38,7 +38,7 @@ def get_contract_workgroup_sizes( def 
get_contract_reduction_sizes( - configuration: Configuration, tile_dims: str + configuration: iree_codegen.CompilationInfoAttr, tile_dims: str ) -> list[int]: _m, _n, k = configuration.lowering_config.reduction_tile_sizes reduction_size = [0] * len(tile_dims) diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index a63576808..9f4afbb19 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -59,11 +59,10 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [], 0, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - lowering_config = config.lowering_config + lowering_config = compilation_info.lowering_config assert lowering_config.workgroup_tile_sizes == [128, 320, 0] assert lowering_config.reduction_tile_sizes == [0, 0, 32] @@ -87,12 +86,27 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [256, 1, 1], 64, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info ) - assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0] - assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16] + assert compilation_info.lowering_config.workgroup_tile_sizes == [ + 1, + 1, + 464, + 320, + 1, + 1, + 0, + ] + assert compilation_info.lowering_config.reduction_tile_sizes == [ + 0, + 0, + 0, + 0, + 0, + 0, + 16, + ] def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: @@ -114,18 +128,49 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: translation_info = iree_codegen.TranslationInfoAttr.get( pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config = common.Configuration( - translation_info=translation_info, - lowering_config=lowering_config, - ) - assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16] - assert dispatch_parser.get_contract_workgroup_sizes(config, "nmk") == [8, 4, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "nmk") == [0, 0, 16] - assert dispatch_parser.get_contract_workgroup_sizes(config, "knm") == [0, 8, 4] - assert dispatch_parser.get_contract_reduction_sizes(config, "knm") == [16, 0, 0] - assert dispatch_parser.get_contract_workgroup_sizes(config, "kkk") == [0, 0, 0] - assert dispatch_parser.get_contract_reduction_sizes(config, "kkk") == [16, 16, 16] + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "mnk") == [ + 4, + 8, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "mnk") == [ + 0, + 0, + 16, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "nmk") == [ + 8, + 4, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "nmk") == [ + 0, + 0, + 16, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "knm") == [ + 0, + 8, + 4, + ] + assert 
dispatch_parser.get_contract_reduction_sizes(compilation_info, "knm") == [ + 16, + 0, + 0, + ] + assert dispatch_parser.get_contract_workgroup_sizes(compilation_info, "kkk") == [ + 0, + 0, + 0, + ] + assert dispatch_parser.get_contract_reduction_sizes(compilation_info, "kkk") == [ + 16, + 16, + 16, + ] def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 3aa932dc4..3c195520c 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -64,7 +64,7 @@ class CandidateTracker: candidate_id: int dispatch_mlir_path: Optional[Path] = None dispatch_config_path: Optional[Path] = None - configuration: Optional[candidate_gen.Configuration] = None + configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None compilation_successful: Optional[bool] = None compiled_dispatch_path: Optional[Path] = None compiled_dispatch_hash: Optional[str] = None From f7d2681124b65588090ecebf1b778d5650ff982d Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:44:58 -0600 Subject: [PATCH 06/39] Update user docs for running `llm server` + upgrade `gguf` to `0.11.0` (#676) # Description Did a pass through and made updates + fixes to the user docs for `e2e_llama8b_mi300x.md`. 1. Update install instructions for `shark-ai` 2. Update nightly install instructions for `shortfin` and `sharktank` 3. Update paths for model artifacts to ensure they work with `llama3.1-8b-fp16-instruct` 4. Remove steps to `write edited config`. No longer needed after #487 Added back `sentencepiece` as a requirement for `sharktank`. Not having it caused `export_paged_llm_v1` to break when installing nightly: ```text ModuleNotFoundError: No module named 'sentencepiece' ``` This was obfuscated when building from source, because `shortfin` includes `sentencepiece` in `requirements-tests.txt`. --- docs/shortfin/llm/user/e2e_llama8b_mi300x.md | 82 ++++++-------------- sharktank/requirements.txt | 5 +- 2 files changed, 23 insertions(+), 64 deletions(-) diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md index 4a8423bc8..313a8086c 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -22,32 +22,28 @@ python -m venv --prompt shark-ai .venv source .venv/bin/activate ``` -### Install `shark-ai` +## Install stable shark-ai packages -You can install either the `latest stable` version of `shark-ai` -or the `nightly` version: - -#### Stable + ```bash -pip install shark-ai +pip install shark-ai[apps] sharktank ``` -#### Nightly - -```bash -pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels -pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels -``` +### Nightly packages -#### Install dataclasses-json +To install nightly packages: - + ```bash -pip install dataclasses-json +pip install shark-ai[apps] sharktank \ + --pre --find-links https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels ``` +See also the +[instructions here](https://github.com/nod-ai/shark-ai/blob/main/docs/nightly_releases.md). + ### Define a directory for export files Create a new directory for us to export files like @@ -78,8 +74,8 @@ This example uses the `llama8b_f16.gguf` and `tokenizer.json` files that were downloaded in the previous step. 
```bash -export MODEL_PARAMS_PATH=$EXPORT_DIR/llama3.1-8b/llama8b_f16.gguf -export TOKENIZER_PATH=$EXPORT_DIR/llama3.1-8b/tokenizer.json +export MODEL_PARAMS_PATH=$EXPORT_DIR/meta-llama-3.1-8b-instruct.f16.gguf +export TOKENIZER_PATH=$EXPORT_DIR/tokenizer.json ``` #### General env vars @@ -91,8 +87,6 @@ The following env vars can be copy + pasted directly: export MLIR_PATH=$EXPORT_DIR/model.mlir # Path to export config.json file export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json -# Path to export edited_config.json file -export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json # Path to export model.vmfb file export VMFB_PATH=$EXPORT_DIR/model.vmfb # Batch size for kvcache @@ -108,7 +102,7 @@ to export our model to `.mlir` format. ```bash python -m sharktank.examples.export_paged_llm_v1 \ - --irpa-file=$MODEL_PARAMS_PATH \ + --gguf-file=$MODEL_PARAMS_PATH \ --output-mlir=$MLIR_PATH \ --output-config=$OUTPUT_CONFIG_PATH \ --bs=$BS @@ -137,37 +131,6 @@ iree-compile $MLIR_PATH \ -o $VMFB_PATH ``` -## Write an edited config - -We need to write a config for our model with a slightly edited structure -to run with shortfin. This will work for the example in our docs. -You may need to modify some of the parameters for a specific model. - -### Write edited config - -```bash -cat > $EDITED_CONFIG_PATH << EOF -{ - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 131072, - "attn_head_count": 8, - "attn_head_dim": 128, - "prefill_batch_sizes": [ - $BS - ], - "decode_batch_sizes": [ - $BS - ], - "transformer_block_count": 32, - "paged_kv_cache": { - "block_seq_stride": 16, - "device_block_count": 256 - } -} -EOF -``` - ## Running the `shortfin` LLM server We should now have all of the files that we need to run the shortfin LLM server. @@ -178,15 +141,14 @@ Verify that you have the following in your specified directory ($EXPORT_DIR): ls $EXPORT_DIR ``` -- edited_config.json +- config.json +- meta-llama-3.1-8b-instruct.f16.gguf +- model.mlir - model.vmfb +- tokenizer_config.json +- tokenizer.json -### Launch server: - - +### Launch server #### Run the shortfin server @@ -209,7 +171,7 @@ Run the following command to launch the Shortfin LLM Server in the background: ```bash python -m shortfin_apps.llm.server \ --tokenizer_json=$TOKENIZER_PATH \ - --model_config=$EDITED_CONFIG_PATH \ + --model_config=$OUTPUT_CONFIG_PATH \ --vmfb=$VMFB_PATH \ --parameters=$MODEL_PARAMS_PATH \ --device=hip > shortfin_llm_server.log 2>&1 & @@ -252,7 +214,7 @@ port = 8000 # Change if running on a different port generate_url = f"http://localhost:{port}/generate" def generation_request(): - payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}} + payload = {"text": "Name the capital of the United States.", "sampling_params": {"max_completion_tokens": 50}} try: resp = requests.post(generate_url, json=payload) resp.raise_for_status() # Raises an HTTPError for bad responses diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index e7181284a..90cbedded 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -1,12 +1,9 @@ iree-turbine # Runtime deps. -gguf==0.10.0 +gguf>=0.11.0 numpy<2.0 -# Needed for newer gguf versions (TODO: remove when gguf package includes this) -# sentencepiece>=0.1.98,<=0.2.0 - # Model deps. 
huggingface-hub==0.22.2 transformers==4.40.0 From 3d8cad883c2251579f0dd9862722227a1acc9d73 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 12 Dec 2024 21:16:44 -0500 Subject: [PATCH 07/39] A basic working version of the flux model (#663) This version of the flux model should work, as it directly modifies the reference implementation, but could really use some refactoring, especially to reduce code duplication --------- Co-authored-by: Boian Petkantchin --- sharktank/sharktank/models/flux/flux.py | 251 +++++++++++++++++++++ sharktank/tests/layers/mmdit_test.py | 20 +- sharktank/tests/models/flux/flux_test.py | 271 +++++++++++++++++++++++ 3 files changed, 538 insertions(+), 4 deletions(-) create mode 100644 sharktank/sharktank/models/flux/flux.py create mode 100644 sharktank/tests/models/flux/flux_test.py diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py new file mode 100644 index 000000000..ac63f47a0 --- /dev/null +++ b/sharktank/sharktank/models/flux/flux.py @@ -0,0 +1,251 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# Copyright 2024 Black Forest Labs. Inc. and Flux Authors +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Model adapted from black-forest-labs' flux implementation +https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py +""" + +import math +from dataclasses import dataclass +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...layers import * +from ...types import * +from ...utils.create_cache import * +from ... import ops + +__all__ = [ + "FluxModelV1", +] + +################################################################################ +# Models +################################################################################ + + +@dataclass +class FluxParams: + in_channels: int + out_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class FluxModelV1(ThetaLayer): + """FluxModel adapted from Black Forest Lab's implementation.""" + + def __init__(self, theta: Theta, params: FluxParams): + super().__init__( + theta, + ) + + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.add_module("img_in", LinearLayer(theta("img_in"))) + # TODO: Refactor this pattern to an MLPEmbedder like src implementatio + self.add_module("time_in_0", LinearLayer(theta("time_in.0"))) + self.add_module("time_in_1", LinearLayer(theta("time_in.1"))) + self.add_module("vector_in_0", LinearLayer(theta("vector_in.0"))) + self.add_module("vector_in_1", LinearLayer(theta("vector_in.1"))) + self.guidance = False + if params.guidance_embed: + self.guidance = True + self.add_module("guidance_in_0", 
LinearLayer(theta("guidance_in.0"))) + self.add_module("guidance_in_1", LinearLayer(theta("guidance_in.1"))) + self.add_module("txt_in", LinearLayer(theta("txt_in"))) + + self.double_blocks = nn.ModuleList( + [ + MMDITDoubleBlock( + theta("double_blocks", i), + self.num_heads, + ) + for i in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + MMDITSingleBlock( + theta("single_blocks", i), + self.num_heads, + ) + for i in range(params.depth_single_blocks) + ] + ) + + self.add_module( + "last_layer", + LastLayer(theta("last_layer")), + ) + + def forward( + self, + img: AnyTensor, + img_ids: AnyTensor, + txt: AnyTensor, + txt_ids: AnyTensor, + timesteps: AnyTensor, + y: AnyTensor, + guidance: AnyTensor | None = None, + ) -> AnyTensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + time_in_0 = self.time_in_0(timestep_embedding(timesteps, 256)) + time_in_silu = ops.elementwise(F.silu, time_in_0) + vec = self.time_in_1(time_in_silu) + if self.guidance: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + guidance_inp = timestep_embedding(guidance, 256) + guidance0 = self.guidance_in0(guidance_inp) + guidance_silu = ops.elementwise(F.silu, guidance0) + guidance_out = self.guidance_in1(guidance_silu) + vec = vec + self.guidance_in(guidance_out) + vector_in_0 = self.vector_in_0(y) + vector_in_silu = ops.elementwise(F.silu, vector_in_0) + vector_in_1 = self.vector_in_1(vector_in_silu) + vec = vec + vector_in_1 + + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.last_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + +################################################################################ +# Layers +################################################################################ + + +# TODO: Refactor these functions to other files. Rope can probably be merged with +# our rotary embedding layer, some of these functions are shared with layers/mmdit.py +def timestep_embedding( + t: AnyTensor, dim, max_period=10000, time_factor: float = 1000.0 +): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +def layer_norm(inp): + weight = torch.ones(inp.shape) + bias = torch.zeros(inp.shape) + return ops.layer_norm(inp, weight, bias, eps=1e-6) + + +def qk_norm(q, k, v, rms_q, rms_k): + return rms_q(q).to(v), rms_k(k).to(v) + + +def rope(pos: AnyTensor, dim: int, theta: int) -> AnyTensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + # out = out.view(out.shape[0], out.shape[1], out.shape[2], out.shape[3], 2, 2) + out = out.view(out.shape[0], out.shape[1], out.shape[2], 2, 2) + return out.float() + + +class EmbedND(torch.nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: AnyTensor) -> AnyTensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +class LastLayer(ThetaLayer): + def __init__( + self, + theta: Theta, + ): + super().__init__(theta) + self.add_module("outlinear", LinearLayer(theta("outlinear"))) + self.add_module("ada_linear", LinearLayer(theta("ada_linear"))) + + def forward(self, x: AnyTensor, vec: AnyTensor) -> AnyTensor: + silu = ops.elementwise(F.silu, vec) + lin = self.ada_linear(silu) + shift, scale = lin.chunk(2, dim=1) + x = (1 + scale[:, None, :]) * layer_norm(x) + shift[:, None, :] + x = self.outlinear(x) + return x diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py index d265b33d8..a90f5dc00 100644 --- a/sharktank/tests/layers/mmdit_test.py +++ b/sharktank/tests/layers/mmdit_test.py @@ -17,16 +17,17 @@ MMDITDoubleBlock, MMDITSingleBlock, ) -import sharktank.ops as ops from sharktank.layers.testing import ( make_mmdit_double_block_random_theta, make_mmdit_single_block_random_theta, ) -from sharktank.types.tensors import DefaultPrimitiveTensor +from sharktank.utils.testing import TempDirTestBase +from sharktank.types import Dataset, Theta -class MMDITTest(unittest.TestCase): +class MMDITTest(TempDirTestBase): def setUp(self): + super().setUp() torch.manual_seed(12345) self.hidden_size = 3072 self.num_heads = 24 @@ -35,6 +36,7 @@ def setUp(self): def testDoubleExport(self): theta = make_mmdit_double_block_random_theta() + theta = self.save_load_theta(theta) mmdit = MMDITDoubleBlock( theta=theta, num_heads=self.num_heads, @@ -58,6 +60,7 @@ def _(model, img, txt, vec, rot) -> torch.Tensor: def testSingleExport(self): theta = make_mmdit_single_block_random_theta() + theta = self.save_load_theta(theta) mmdit = MMDITSingleBlock( theta=theta, num_heads=self.num_heads, @@ -73,10 +76,19 @@ def testSingleExport(self): def _(model, inp, vec, rot) -> torch.Tensor: return model.forward(inp, vec, rot) - output = aot.export(fxb) + output = aot.export(fxb, import_symbolic_shape_expressions=True) 
output.verify() asm = str(output.mlir_module) + def save_load_theta(self, theta: Theta): + # Roundtrip to disk to avoid treating parameters as constants that would appear + # in the MLIR. + theta.rename_tensors_to_paths() + dataset = Dataset(root_theta=theta, properties={}) + file_path = self._temp_dir / "parameters.irpa" + dataset.save(file_path) + return Dataset.load(file_path).root_theta + if __name__ == "__main__": unittest.main() diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py new file mode 100644 index 000000000..ea80c7b42 --- /dev/null +++ b/sharktank/tests/models/flux/flux_test.py @@ -0,0 +1,271 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest + +import torch + +from iree.turbine import aot +from sharktank.models.flux.flux import ( + FluxModelV1, + FluxParams, +) +import sharktank.ops as ops +from sharktank.layers.testing import ( + make_rand_torch, +) +from sharktank.types.tensors import DefaultPrimitiveTensor +from sharktank.types.theta import Dataset, Theta +from sharktank.utils.testing import TempDirTestBase + + +# TODO: Refactor this to a function that generates random toy weights, possibly +# to another file +dtype = torch.float32 +in_channels = 64 +in_channels2 = 128 +hidden_size = 3072 +mlp_ratio = 4.0 +mlp_hidden_size = int((mlp_ratio - 1) * hidden_size) +mlp_hidden_size2 = int(mlp_ratio * hidden_size) +mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size) +mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size) +mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size) +context_in_dim = 4096 +time_dim = 256 +vec_dim = 768 +patch_size = 1 +out_channels = 64 + + +def make_random_theta(): + return Theta( + { + "img_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, in_channels), dtype=dtype) + ), + "img_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "txt_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, context_in_dim), dtype=dtype) + ), + "txt_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.0.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, time_dim), dtype=dtype) + ), + "time_in.0.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.1.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "time_in.1.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.0.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) + ), + "vector_in.0.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.1.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "vector_in.1.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), 
+ "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "single_blocks.0.attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "single_blocks.0.attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear1.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size5,), dtype=dtype) + ), + "single_blocks.0.linear1.weight": DefaultPrimitiveTensor( + 
data=make_rand_torch((mlp_hidden_size5, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) + ), + "single_blocks.0.mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "single_blocks.0.mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "last_layer.outlinear.weight": DefaultPrimitiveTensor( # + data=make_rand_torch( + (patch_size * patch_size * out_channels, hidden_size), dtype=dtype + ) + ), + "last_layer.outlinear.bias": DefaultPrimitiveTensor( # + data=make_rand_torch( + (patch_size * patch_size * out_channels,), dtype=dtype + ) + ), + "last_layer.ada_linear.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) + ), + "last_layer.ada_linear.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2,), dtype=dtype) + ), + } + ) + + +class FluxTest(TempDirTestBase): + def setUp(self): + super().setUp() + torch.manual_seed(12345) + self.hidden_size = 3072 + self.num_heads = 24 + self.batch_size = 5 + + def testExport(self): + params = FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=1, + depth_single_blocks=1, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ) + theta = make_random_theta() + theta = self.save_load_theta(theta) + flux = FluxModelV1( + theta=theta, + params=params, + ) + + img = torch.rand([self.batch_size, 1024, 64]) + img_ids = torch.rand([self.batch_size, 1024, 3]) + txt = torch.rand([self.batch_size, 512, 4096]) + txt_ids = torch.rand([self.batch_size, 512, 3]) + timesteps = torch.rand([self.batch_size]) + y = torch.rand([self.batch_size, 768]) + + flux.forward(img, img_ids, txt, txt_ids, timesteps, y) + fxb = aot.FxProgramsBuilder(flux) + + @fxb.export_program( + name="flux", args=(img, img_ids, txt, txt_ids, timesteps, y), strict=False + ) + def _(model, img, img_ids, txt, txt_ids, timesteps, y) -> torch.Tensor: + return model.forward(img, img_ids, txt, txt_ids, timesteps, y) + + output = aot.export(fxb) + output.verify() + asm = str(output.mlir_module) + + def save_load_theta(self, theta: Theta): + # Roundtrip to disk to avoid treating parameters as constants that would appear + # in the MLIR. 
+ theta.rename_tensors_to_paths() + dataset = Dataset(root_theta=theta, properties={}) + file_path = self._temp_dir / "parameters.irpa" + dataset.save(file_path) + return Dataset.load(file_path).root_theta + + +if __name__ == "__main__": + unittest.main() From 77ca02fcba07f100ee5f56a3030d3ffc47501031 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 12 Dec 2024 21:53:05 -0800 Subject: [PATCH 08/39] Enable flash attention by default (#690) --- sharktank/sharktank/layers/configs/llm_configs.py | 2 +- sharktank/sharktank/layers/paged_llama_attention_block.py | 4 +--- sharktank/sharktank/utils/cli.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 8a443e6ca..88f5c344c 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -167,7 +167,7 @@ class LlamaModelConfig: tensor_parallelism_size: int = 1 # Which attention kernel to use. - attention_kernel: str = "decomposed" + attention_kernel: str = "torch" # Indicates if running with HuggingFace implementation and ensures # numerical equivalency to HuggingFace's LLaMa if true (by modifying diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 22647bf49..6bd33c93f 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -216,14 +216,12 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_weights, values ) # (bs, heads, slen, head_dim) else: - is_causal = True - attention_mask = None attn_output = ops.scaled_dot_product_attention( q=xq, # [bs, ..., sl, dim] k=keys, # [bs, ..., sl, dim] v=values, # [bs, ..., sl, dim] a=attention_mask, # [bs, ..., sl, sl] - is_causal=is_causal, # assumes causal masking when true + is_causal=False, # assumes causal masking when true scale=None, # defaults to 1/sqrt(dim) ) diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 99917c2d3..9fefeb66f 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -66,7 +66,7 @@ def add_model_options(parser: argparse.ArgumentParser): parser.add_argument( "--attention-kernel", type=str, - default="decomposed", + default="torch", choices=["decomposed", "torch"], ) parser.add_argument( From ec1424e2efc2deac115be56f0266eceb8a67644b Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Thu, 12 Dec 2024 22:17:21 -0800 Subject: [PATCH 09/39] Add block_seq_stride flag (#692) Add `block_seq_stride` flag --------- Co-authored-by: Rob Suderman --- sharktank/sharktank/examples/export_paged_llm_v1.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 900c1a9ae..ad297bcce 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -45,6 +45,12 @@ def main(): type=lambda arg: [int(bs) for bs in arg.split(",")], default="4", ) + parser.add_argument( + "--block-seq-stride", + help="Block sequence stride for paged KV cache, must divide evenly into the context length", + type=int, + default="16", + ) parser.add_argument( "--verbose", help="Include verbose logging", @@ -76,6 +82,7 @@ def main(): static_tables=False, # Rely on the compiler for 
hoisting tables. kv_cache_type="direct" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, + block_seq_stride=args.block_seq_stride, ) llama_config.fake_quant = args.fake_quant From bcecd43c9a0e939c0438191897b7000b8139d312 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 13 Dec 2024 11:15:59 +0100 Subject: [PATCH 10/39] [sharktank] Mark tests as expected to fail (#686) Marks tests that fail with `torch==2.{4.0,4.1,5.0,5.1}+cpu` as expected to fail: * `test_sharded_conv2d_with_iree` * `test_sharded_resnet_block_with_iree` * `testExportNondecomposed` * `testExportWithArgumentDeviceAffinities` --- sharktank/tests/export_test.py | 6 ++++++ sharktank/tests/layers/paged_llama_attention_block_test.py | 6 ++++++ sharktank/tests/layers/sharded_conv2d_with_iree_test.py | 5 +++++ .../models/punet/sharded_resnet_block_with_iree_test.py | 3 +++ 4 files changed, 20 insertions(+) diff --git a/sharktank/tests/export_test.py b/sharktank/tests/export_test.py index 20b7de734..92992121b 100644 --- a/sharktank/tests/export_test.py +++ b/sharktank/tests/export_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + from sharktank.types import ( ReplicatedTensor, SplitPrimitiveTensor, @@ -70,6 +72,10 @@ def testGetFlatArgumentDeviceAffinities(self): } assert_dicts_equal(affinities, expected_affinities) + @pytest.mark.xfail( + torch.__version__ >= (2, 4), + reason="https://github.com/nod-ai/shark-ai/issues/685", + ) def testExportWithArgumentDeviceAffinities(self): args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]])) diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index 63251c5a9..d5cb6863d 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + import logging logging.basicConfig(level=logging.DEBUG) @@ -118,6 +120,10 @@ def forward(self, h, seq_block_ids, cache_state): asm = str(output.mlir_module) self.assertNotIn("scaled_dot_product_attention", asm) + @pytest.mark.xfail( + torch.__version__ >= (2, 4), + reason="https://github.com/nod-ai/shark-ai/issues/684", + ) def testExportNondecomposed(self): dtype = torch.float32 diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 9b29e5761..f0153e25b 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -6,6 +6,8 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + from pathlib import Path import tempfile import torch @@ -181,6 +183,9 @@ def run_test_sharded_conv2d_with_iree( ) +@pytest.mark.xfail( + torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/682" +) def test_sharded_conv2d_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 581584369..c24dc149e 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -228,6 +228,9 @@ def run_test_sharded_resnet_block_with_iree( strict=True, raises=AssertionError, ) +@pytest.mark.xfail( + torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/683" +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], From 0bec8f1737242e03a21da1f7e4c975038ed16cd0 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 13 Dec 2024 12:10:41 +0100 Subject: [PATCH 11/39] [sharktank] Use reshape instead of view (#681) Fixes `RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.` by using `reshape` which occurs when running `tests/layers/mmdit_test.py` with Python 3.11.10 and `torch==2.5.1+cpu`. --- sharktank/sharktank/layers/mmdit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py index 0c970ab35..1557883ae 100644 --- a/sharktank/sharktank/layers/mmdit.py +++ b/sharktank/sharktank/layers/mmdit.py @@ -41,7 +41,7 @@ def attention(q, k, v, pe): q=q, k=k, v=v, a=None, is_causal=True, scale=None ) x = ops.permute(x, (0, 2, 1, 3)) - x = x.view(x.shape[0], x.shape[1], -1) + x = x.reshape(x.shape[0], x.shape[1], -1) return x From 332baa32f6ec9022e9c570a4b244e28731b4bda9 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 13 Dec 2024 19:36:18 +0100 Subject: [PATCH 12/39] [sharktank] Extend to test with Python 3.12 (#693) Progress on https://github.com/nod-ai/shark-ai/issues/357 Extends to test with Python 3.12 and newer `torch` versions on Linux. For `torch.compile` support on Python 3.12, torch ge 2.4.0 is required. Testing with torch 2.3.0 is therefore skipped when running Python 3.12. To limit the number of jobs, tests are currently only expanded for Linux, whereas Windows (and unpinning NumPy) should be handled in a follow up. The error logged for `testExportNondecomposed` when failing with torch 2.5.1 slows down the CI by more than 20 minutes, therefore the test is now skipped instead of marked xfail. 
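As a sanity check on the matrix logic described above, here is a small sketch (not part of the patch) that enumerates the job combinations the new matrix is expected to produce, assuming the usual GitHub Actions semantics where `exclude` filters the cross product and `include` appends extra entries:

```python
# Sketch (not part of the patch): expected job combinations for the new
# ci-sharktank.yml matrix, assuming standard GitHub Actions matrix semantics.
from itertools import product

oses = ["ubuntu-24.04"]
python_versions = ["3.11", "3.12"]
torch_versions = ["2.3.0", "2.4.1", "2.5.1"]

# torch.compile needs torch >= 2.4.0 on Python 3.12, hence the exclusion.
excluded = {("3.12", "2.3.0")}

jobs = [
    (os_, py, torch)
    for os_, py, torch in product(oses, python_versions, torch_versions)
    if (py, torch) not in excluded
]
# The single Windows job is appended via `include`.
jobs.append(("windows-2022", "3.11", "2.3.0"))

for job in jobs:
    print(job)  # 6 combinations in total
```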
--- .github/workflows/ci-sharktank.yml | 23 +++++++++++++------ .../paged_llama_attention_block_test.py | 4 ++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index f169eb2b2..5ee2fbcae 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -23,11 +23,20 @@ concurrency: jobs: test: - name: "Unit Tests and Type Checking" + name: "Unit Tests (${{ matrix.os }}, ${{ matrix.python-version }}, ${{ matrix.torch-version }})" strategy: matrix: - version: [3.11] - os: [ubuntu-24.04, windows-2022] + python-version: ["3.11", "3.12"] + torch-version: ["2.3.0", "2.4.1", "2.5.1"] + os: [ubuntu-24.04] + include: + - os: windows-2022 + python-version: "3.11" + torch-version: "2.3.0" + exclude: + - python-version: "3.12" + # `torch.compile` requires torch>=2.4.0 for Python 3.12+ + torch-version: "2.3.0" fail-fast: false runs-on: ${{matrix.os}} defaults: @@ -42,7 +51,7 @@ jobs: id: setup_python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: ${{matrix.version}} + python-version: ${{matrix.python-version}} - name: Cache Pip Packages uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 @@ -57,7 +66,7 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile --pre --index-url https://download.pytorch.org/whl/test/cpu torch==${{matrix.torch-version}}+cpu pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. @@ -77,7 +86,7 @@ jobs: name: "Data-dependent Tests" strategy: matrix: - version: [3.11] + python-version: [3.11] runs-on: [llama-mi300x-3] fail-fast: false runs-on: ${{matrix.runs-on}} @@ -94,7 +103,7 @@ jobs: id: setup_python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: ${{matrix.version}} + python-version: ${{matrix.python-version}} - name: Create Python venv run: python -m venv ${VENV_DIR} diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index d5cb6863d..bbb52f235 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -124,6 +124,10 @@ def forward(self, h, seq_block_ids, cache_state): torch.__version__ >= (2, 4), reason="https://github.com/nod-ai/shark-ai/issues/684", ) + @pytest.mark.skipif( + torch.__version__ >= (2, 5), + reason="https://github.com/nod-ai/shark-ai/issues/684, error slows down CI", + ) def testExportNondecomposed(self): dtype = torch.float32 From f4876fb21c84dbd8463a085a9d6a5ab9e9326e64 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 16 Dec 2024 18:06:39 +0100 Subject: [PATCH 13/39] Explicitly install nightlies (#699) As `iree-turbine` is specified in `sharktank/requirements.txt`, the dependency is already fulfilled and nightlies don't get installed. One option would be to enforce to `--upgrade` but instead the packages get not installed before stable releases get pulled in via sharktank's requirements file. For workflow running on non-GH hosted runners, `--upgrade` is passed additionaly to avoid using an eventually cached version on the self-hosted runners. 
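One quick way to confirm the reordering has the intended effect is to ask the environment which `iree-turbine` build actually got installed; the sketch below is not part of the patch and assumes the nightly wheels are published as pre-releases, so their version string differs from the stable release:

```python
# Sketch (not part of the patch): check which iree-turbine build ended up in
# the virtualenv after the install steps. Before this change, the stable wheel
# pulled in via sharktank/requirements.txt already satisfied the requirement,
# so the later nightly install step was effectively a no-op.
from importlib.metadata import version

installed = version("iree-turbine")
print(installed)  # nightly builds typically carry a pre-release suffix (e.g. an "rc" tag)
```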
--- .github/workflows/ci-llama-large-tests.yaml | 4 +++- .github/workflows/ci-llama-quick-tests.yaml | 3 ++- .github/workflows/ci-sglang-integration-tests.yml | 4 +++- .github/workflows/ci-shark-ai.yml | 4 +++- .github/workflows/ci-sharktank.yml | 14 ++++++++++---- .github/workflows/ci_eval.yaml | 9 ++++++--- .github/workflows/ci_eval_short.yaml | 3 ++- 7 files changed, 29 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 379a5bac3..2eb8e6496 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -53,11 +53,11 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. @@ -66,6 +66,8 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + pip freeze - name: Run llama tests diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index 418dbea2e..7ad153924 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -53,11 +53,11 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. @@ -66,6 +66,7 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ pip freeze - name: Run llama 8b f16 decomposed test diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml index c3627226e..4eaae8dc4 100644 --- a/.github/workflows/ci-sglang-integration-tests.yml +++ b/.github/workflows/ci-sglang-integration-tests.yml @@ -50,11 +50,11 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ # Use newest possible releases to be able to track commits that may # cause errors. 
@@ -64,6 +64,8 @@ jobs: iree-turbine \ "numpy<2.0" + pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ + # Install SGLang and sentence_transformers pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" pip install sentence_transformers diff --git a/.github/workflows/ci-shark-ai.yml b/.github/workflows/ci-shark-ai.yml index 0e8ae0c87..67d52e5a6 100644 --- a/.github/workflows/ci-shark-ai.yml +++ b/.github/workflows/ci-shark-ai.yml @@ -49,11 +49,11 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ # Install nightly IREE packages. # We could also pin to a known working or stable version. @@ -62,6 +62,8 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ + pip freeze - name: Run LLM Integration Tests diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 5ee2fbcae..c08f2f412 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -63,11 +63,11 @@ jobs: - name: Install pip deps run: | python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile --pre --index-url https://download.pytorch.org/whl/test/cpu torch==${{matrix.torch-version}}+cpu - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. @@ -76,6 +76,8 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + - name: Run sharktank tests if: ${{ !cancelled() }} run: | @@ -116,7 +118,6 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. @@ -125,6 +126,8 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + - name: Run tests # TODO: unify with-t5-data and with-clip-data flags into a single flag # and make it possible to run only tests that require data. @@ -162,18 +165,21 @@ jobs: - name: Install pip deps run: | python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. 
- pip install -f https://iree.dev/pip-release-links.html --pre --upgrade \ + pip install -f https://iree.dev/pip-release-links.html --pre \ iree-base-compiler \ iree-base-runtime \ iree-turbine + + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + - name: Run punet tests run: | pytest -v sharktank/ -m punet_quick \ diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index c41709f38..3b85cb652 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -51,19 +51,20 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. - pip install -f https://iree.dev/pip-release-links.html --pre \ + pip install -f https://iree.dev/pip-release-links.html --pre --upgrade \ iree-base-compiler \ iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ pip freeze - name: Run perplexity test with IREE @@ -109,17 +110,19 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly iree-turbine. # We could also pin to a known working or stable version. pip install -f https://iree.dev/pip-release-links.html --pre --upgrade \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + - name: Run perplexity test with Torch run: | source ${VENV_DIR}/bin/activate diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index fd3b5e5d7..edaaee966 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -50,11 +50,11 @@ jobs: run: | source ${VENV_DIR}/bin/activate python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Install nightly IREE packages. # We could also pin to a known working or stable version. 
@@ -63,6 +63,7 @@ jobs: iree-base-runtime \ iree-turbine + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ pip freeze - name: Run perplexity test with vmfb From ff32f2506eebe503511525846226e97ecbe1818c Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 16 Dec 2024 18:14:29 +0100 Subject: [PATCH 14/39] [sharktank] Test additional version on windows (#697) --- .github/workflows/ci-sharktank.yml | 3 +++ sharktank/tests/examples/main_test.py | 5 +++++ sharktank/tests/layers/sharded_conv2d_with_iree_test.py | 4 ++++ .../models/punet/sharded_resnet_block_with_iree_test.py | 7 ++++++- 4 files changed, 18 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index c08f2f412..19faf3bdd 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -33,6 +33,9 @@ jobs: - os: windows-2022 python-version: "3.11" torch-version: "2.3.0" + - os: windows-2022 + python-version: "3.12" + torch-version: "2.4.1" exclude: - python-version: "3.12" # `torch.compile` requires torch>=2.4.0 for Python 3.12+ diff --git a/sharktank/tests/examples/main_test.py b/sharktank/tests/examples/main_test.py index fb43977df..deac0ae53 100644 --- a/sharktank/tests/examples/main_test.py +++ b/sharktank/tests/examples/main_test.py @@ -4,11 +4,16 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest +import sys import unittest from sharktank.utils.testing import MainRunnerTestBase +@pytest.mark.skipif( + sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698" +) class ShardingTests(MainRunnerTestBase): def testExportFfnNet(self): from sharktank.examples.sharding.export_ffn_net import main diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index f0153e25b..410baddc8 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import pytest +import sys from pathlib import Path import tempfile @@ -186,6 +187,9 @@ def run_test_sharded_conv2d_with_iree( @pytest.mark.xfail( torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/682" ) +@pytest.mark.skipif( + sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698" +) def test_sharded_conv2d_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index c24dc149e..7c23e1d8c 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -4,6 +4,9 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest +import sys + from pathlib import Path import tempfile @@ -19,7 +22,6 @@ import iree.runtime from typing import List, Optional import os -import pytest vm_context: iree.runtime.VmContext = None @@ -231,6 +233,9 @@ def run_test_sharded_resnet_block_with_iree( @pytest.mark.xfail( torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/683" ) +@pytest.mark.skipif( + sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698" +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], From eb1d4c3f18e6e62412ab48bad001e4a699346e43 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 16 Dec 2024 18:28:39 +0100 Subject: [PATCH 15/39] [sharktank] Unpin NumPy (#694) `numpy<2.0` is required for Windows with `torch<2.4`. Expanding testing on Windows with Python 3.12 and torch 2.4.1. --- .github/workflows/ci-sharktank.yml | 4 ++++ sharktank/requirements.txt | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 19faf3bdd..8e41c68b8 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -63,6 +63,10 @@ jobs: path: ${{ env.PIP_CACHE_DIR }} key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + - name: Install numpy + if: ${{ matrix.os == 'windows-2022' && matrix.torch-version == '2.3.0' }} + run: pip install "numpy<2.0" + - name: Install pip deps run: | python -m pip install --no-compile --upgrade pip diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 90cbedded..8a2a8ea3b 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -2,7 +2,7 @@ iree-turbine # Runtime deps. gguf>=0.11.0 -numpy<2.0 +numpy # Model deps. huggingface-hub==0.22.2 From 4f542ac9181b3f054a57e3d4efdd2d049591570d Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Mon, 16 Dec 2024 13:29:48 -0500 Subject: [PATCH 16/39] [tuner] Add direct TD spec generation for candidates (#606) This PR adds direct transform dialect spec generation for candidate configurations. This is the first part of the large refactoring described in https://github.com/nod-ai/shark-ai/pull/577. The way TD specs are generated is by matching against certain types of operations, and then creating a named sequence with `transform.iree.match.cast_compatible_dag_from_root` based on the matched operation. This is done for each configuration found, and the specs are saved to the temporary tuning directory to be used later in tuning. One main difference in the flow of candidate generation is that state is no longer tracked by saving files to a temporary directory. Instead, ir modules are passed to each function, and only at the very end of candidate generation are the transform dialect specs written to files. This makes things cleaner, since there no longer needs to be a coordination of file paths. 
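As a rough illustration of the new flow (this snippet is not part of the PR; the `tuner` import paths, logger name, and file names are assumptions), a caller hands a dispatch module containing a single executable variant with a `root_op`-annotated contraction or convolution to `generate_configs_and_td_specs`, and only writes the returned spec modules out at the end:

    import logging
    from iree.compiler import ir  # type: ignore

    from tuner import candidate_gen          # import path assumed
    from tuner.common import TunerContext    # import path assumed

    with open("dispatch.mlir", "r") as f:    # placeholder input file
        mlir_text = f.read()

    with ir.Context() as ctx:
        tuner_ctx = TunerContext(ctx, logging.getLogger("tune"))
        input_module = ir.Module.parse(mlir_text, ctx)
        # Index 0 is reserved for the default configuration (an empty spec module).
        specs = candidate_gen.generate_configs_and_td_specs(
            input_module=input_module,
            tuner_context=tuner_ctx,
            limit=16,
            num_subgroups=4,
        )
        # Specs stay in memory as ir.Modules and are only written out at the end.
        for i, spec in enumerate(specs):
            with open(f"{i}_spec.mlir", "w") as spec_file:  # placeholder naming scheme
                spec_file.write(str(spec))

In the tuner itself, `libtuner.generate_candidate_specs` wraps this step and stores the specs under the run's `specs_dir`.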
Signed-off-by: Max Dawkins --- tuner/examples/test/__init__.py | 5 + tuner/examples/test/__main__.py | 9 ++ tuner/examples/test/tuner_test.py | 40 +++++++ tuner/tuner/candidate_gen.py | 155 ++++++++++++++++++++++++ tuner/tuner/candidate_gen_test.py | 179 ++++++++++++++++++++++++++++ tuner/tuner/common.py | 17 ++- tuner/tuner/dispatch_constraints.py | 6 +- tuner/tuner/dispatch_parser.py | 94 +++++++++++++++ tuner/tuner/dispatch_parser_test.py | 98 +++++++++++++++ tuner/tuner/libtuner.py | 65 ++++++++++ tuner/tuner/op_matchers.py | 178 +++++++++++++++++++++++++++ tuner/tuner/spec_builder.py | 62 ++++++++++ 12 files changed, 900 insertions(+), 8 deletions(-) create mode 100644 tuner/examples/test/__init__.py create mode 100644 tuner/examples/test/__main__.py create mode 100644 tuner/examples/test/tuner_test.py create mode 100644 tuner/tuner/op_matchers.py create mode 100644 tuner/tuner/spec_builder.py diff --git a/tuner/examples/test/__init__.py b/tuner/examples/test/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/tuner/examples/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/tuner/examples/test/__main__.py b/tuner/examples/test/__main__.py new file mode 100644 index 000000000..4f426e110 --- /dev/null +++ b/tuner/examples/test/__main__.py @@ -0,0 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import tuner_test + +tuner_test.main() diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py new file mode 100644 index 000000000..d8c35d60b --- /dev/null +++ b/tuner/examples/test/tuner_test.py @@ -0,0 +1,40 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from tuner import libtuner + + +def main(): + args = libtuner.parse_arguments() + + path_config = libtuner.PathConfig() + path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() + candidate_trackers: list[libtuner.CandidateTracker] = [] + stop_after_phase: str = args.stop_after + + print("Setup logging") + libtuner.setup_logging(args, path_config) + print(path_config.run_log, end="\n\n") + + if not args.dry_run: + print("Validating devices") + libtuner.validate_devices(args.devices) + print("Validation successful!\n") + + print("Generating candidates...") + candidates = libtuner.generate_candidate_specs( + args, path_config, candidate_trackers + ) + print(f"Stored candidate specs in {path_config.specs_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Check the detailed execution logs in:") + print(path_config.run_log.resolve()) + + for candidate in candidate_trackers: + libtuner.logging.debug(candidate) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index ed150bfec..ed4d63f7d 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -35,6 +35,7 @@ from .common import * from .dispatch_constraints import * from .dispatch_parser import * +from .spec_builder import * tune_logger = logging.getLogger("tune") @@ -106,6 +107,15 @@ def apply_params( """Apply parameter transformations to the operation.""" pass + @abstractmethod + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + """Generate a transform dialect spec that applies the compilation info attr.""" + pass + class DispatchTunerRegistry: def __init__(self): @@ -130,6 +140,68 @@ def find_handler(self, op_name: str) -> DispatchTuner: assert False, "Dispatch kind not supported" +class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser): + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> MLIRTransformation: + raise NotImplementedError + + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + contraction_op: ir.Operation = self.get_contraction_operation(ir_module) + lhs_type = ir.ShapedType(contraction_op.operands[0].type) + rhs_type = ir.ShapedType(contraction_op.operands[1].type) + acc_type = ir.ShapedType(contraction_op.operands[2].type) + M = acc_type.get_dim_size(0) + N = acc_type.get_dim_size(1) + K = lhs_type.get_dim_size(1) + # TODO(Max191): Get the function name from the func.func in the input module. 
+ func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" + return build_td_spec( + ir_module.context, contraction_op, compilation_info, func_name + ) + + +class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser): + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> MLIRTransformation: + raise NotImplementedError + + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + conv_op: ir.Operation = self.get_conv_operation(ir_module) + assert ( + conv_op.name == "linalg.conv_2d_nhwc_hwcf" + ), "expected linalg.conv_2d_nhwc_hwcf" + lhs_type = ir.ShapedType(conv_op.operands[0].type) + rhs_type = ir.ShapedType(conv_op.operands[1].type) + acc_type = ir.ShapedType(conv_op.operands[2].type) + N = acc_type.get_dim_size(0) + H = acc_type.get_dim_size(1) + W = acc_type.get_dim_size(2) + C = rhs_type.get_dim_size(2) + P = rhs_type.get_dim_size(0) + Q = rhs_type.get_dim_size(1) + F = rhs_type.get_dim_size(3) + conv_type = conv_op.name.split(".")[-1] + # TODO(Max191): Get the function name from the func.func in the input module. + func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" + return build_td_spec(ir_module.context, conv_op, compilation_info, func_name) + + class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, @@ -174,6 +246,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class ConvTuner(DispatchTuner, ConvParser): def get_transform_function_conv( @@ -235,6 +314,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class ContractionTuner(DispatchTuner, ContractionParser): def get_transform_function_broadcast_rhs_mmt( @@ -306,6 +392,13 @@ def apply_params( "", ) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class BatchMmtTuner(DispatchTuner, BatchMmtParser): def get_transform_function_batch_mmt( @@ -353,6 +446,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): def get_transform_function_batch_matmul( @@ -409,6 +509,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + @dataclass class OpWalkResult: @@ -452,6 +559,7 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") +# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td. 
def tune( input: str, # Path to the mlir file to be tuned output: str = "", # Path to the output directory, auto creates one if not given @@ -527,6 +635,53 @@ def tune( tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") +def generate_configs_and_td_specs( + input_module: ir.Module, # Path to the mlir file to be tuned + tuner_context: TunerContext, + limit: int = 4096, # Max candidates to be generated + num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints +) -> list[ir.Module]: + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + ContractionOpInterfaceTuner(), + ConvolutionOpInterfaceTuner(), + ] + ) + + walk_result: OpWalkResult = walk_mlir_op(input_module, dispatch_tuner_registry) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes( + str(input_module).splitlines() + ) + tune_logger.debug(str(problem_size)) + + # Index 0 is reserved for default config, so it gets no td spec. + with ir.Location.unknown() as loc: + empty_module = ir.Module.create(loc) + config_specs: list[ir.Module] = [empty_module] + + # Get the MMA intrinisic intructions supported by the target. + variant_op_list = iree_codegen.get_executable_variant_ops(input_module) + assert len(variant_op_list) == 1, "Expect one executable variant op" + variant_op = variant_op_list[0] + mma_list = iree_codegen.query_mma_intrinsics(variant_op) + for i, config in enumerate( + generate_solutions(tuner_context, problem_size, num_subgroups, mma_list) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + td_spec_module = dispatch_tuner.get_td_spec(input_module, config) + assert td_spec_module, "Failed to generate transform dialect spec" + config_specs.append(td_spec_module) + + tune_logger.info(f"Generated {len(config_specs)} tuning specs") + return config_specs + + def main(): parser = argparse.ArgumentParser() parser.add_argument("input", help="Input mlir file", type=str) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 0428ab7d2..d135a8502 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -15,9 +15,11 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore from iree.compiler.dialects import iree_codegen # type: ignore +from iree.compiler.dialects import transform # type: ignore from . import candidate_gen from . import common +from . 
import op_matchers @pytest.fixture @@ -36,6 +38,183 @@ def remove_comments(mlir: str) -> str: ) +def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2048x2048xf16>, %arg1: tensor<2048x2048xf16>) -> tensor<2048x2048xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2048x2048xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32> + %2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + {root_op} + ins(%arg0, %arg1 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) + outs(%1 : tensor<2048x2048xf32>) { + ^bb0(%in: f16, %in_0: f16, %out: f32): + %3 = arith.extf %in : f16 to f32 + %4 = arith.extf %in_0 : f16 to f32 + %5 = arith.mulf %3, %4 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<2048x2048xf32> + return %2 : tensor<2048x2048xf32> + } + }""" + + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[8, 8, 0], + reduction=[0, 0, 8], + subgroup_m_count=16, + subgroup_n_count=16, + ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, waves_per_eu=8) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 16, config_dict + ) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + + ir_module = ir.Module.parse(module_str, context) + + tuner = candidate_gen.ContractionOpInterfaceTuner() + td_spec_module = tuner.get_td_spec(ir_module, compilation_info) + assert td_spec_module + + named_sequence_ops: list[ + transform.NamedSequenceOp + ] = op_matchers.get_ops_from_module( + module=td_spec_module, + fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp), + ) + apply_config_sequence = None + matcher_sequence = None + entry_point = None + for op in named_sequence_ops: + if str(op.opview.sym_name) == '"apply_op_config"': + apply_config_sequence = op + elif str(op.opview.sym_name) == '"__kernel_config"': + entry_point = op + else: + matcher_sequence = op + + assert apply_config_sequence + assert matcher_sequence + assert entry_point + matcher_sequence_str = str(matcher_sequence) + + assert ( + "mma_kind = #iree_gpu.mma_layout" in matcher_sequence_str + ) + assert "subgroup_m_count = 16" in matcher_sequence_str + assert "subgroup_n_count = 16" in matcher_sequence_str + assert "pipeline = LLVMGPUVectorDistribute" in matcher_sequence_str + assert "workgroup_size = [16, 16, 1]" in matcher_sequence_str + assert "subgroup_size = 16" in matcher_sequence_str + assert "workgroup = [8, 8, 0]" in matcher_sequence_str + assert "reduction = [0, 0, 8]" in matcher_sequence_str + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in matcher_sequence_str + ) + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in matcher_sequence_str + + +def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None: + context = 
tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2x34x34x2048xi8>, %arg1: tensor<3x3x2048x2048xi8>) -> tensor<2x32x32x2048xi32> { + %cst = arith.constant 0 : i32 + %0 = tensor.empty() : tensor<2x32x32x2048xi32> + %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> + %2 = linalg.conv_2d_nhwc_hwcf {root_op} + ins(%arg0, %arg1 : tensor<2x34x34x2048xi8>, tensor<3x3x2048x2048xi8>) + outs(%1 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> + return %2 : tensor<2x32x32x2048xi32> + } + }""" + + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[1, 1, 464, 320, 0, 0, 0], + reduction=[0, 0, 0, 0, 1, 1, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=False) + config_dict = common.get_translation_info_config(pipeline_options, waves_per_eu=2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + + ir_module = ir.Module.parse(module_str, context) + + tuner = candidate_gen.ConvolutionOpInterfaceTuner() + td_spec_module = tuner.get_td_spec(ir_module, compilation_info) + assert td_spec_module + + named_sequence_ops: list[ + transform.NamedSequenceOp + ] = op_matchers.get_ops_from_module( + module=td_spec_module, + fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp), + ) + apply_config_sequence = None + matcher_sequence = None + entry_point = None + for op in named_sequence_ops: + if str(op.opview.sym_name) == '"apply_op_config"': + apply_config_sequence = op + elif str(op.opview.sym_name) == '"__kernel_config"': + entry_point = op + else: + matcher_sequence = op + + assert apply_config_sequence + assert matcher_sequence + assert entry_point + + matcher_sequence_str = str(matcher_sequence) + + assert ( + "mma_kind = #iree_gpu.mma_layout" in matcher_sequence_str + ) + assert "subgroup_m_count = 1" in matcher_sequence_str + assert "subgroup_n_count = 4" in matcher_sequence_str + assert "pipeline = LLVMGPUVectorDistribute" in matcher_sequence_str + assert "workgroup_size = [256, 1, 1]" in matcher_sequence_str + assert "subgroup_size = 64" in matcher_sequence_str + assert "workgroup = [1, 1, 464, 320, 0, 0, 0]" in matcher_sequence_str + assert "reduction = [0, 0, 0, 0, 1, 1, 16]" in matcher_sequence_str + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in matcher_sequence_str + ) + + def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 5c79bd8dd..78e3a8e9d 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -78,6 +78,14 @@ class MatmulSize: B: int = 1 +@dataclass +class ContractionDimensions: + batch: list[int] + m: list[int] + n: list[int] + k: list[int] + + @dataclass class ProblemSize: matmul_size: MatmulSize @@ -98,13 +106,12 @@ def get_compatible_mfma_intrinsics( def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma a_type, 
b_type, c_type = mma_attr.abc_element_types - if problem_size.res_type.element_type != c_type: + if not isinstance(problem_size.res_type.element_type, type(c_type)): return False if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if ( - problem_size.lhs_type.element_type != a_type - or problem_size.rhs_type.element_type != b_type - ): + if not isinstance( + problem_size.lhs_type.element_type, type(a_type) + ) or not isinstance(problem_size.rhs_type.element_type, type(b_type)): return False return True diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 797c83534..914c04bbf 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -157,9 +157,9 @@ def getMMAAttr( a_type, b_type, c_type = mma_attr.abc_element_types mnk = mma_attr.mnk_shape if ( - a_type == lhs_type - and b_type == rhs_type - and c_type == output_type + isinstance(a_type, type(lhs_type)) + and isinstance(b_type, type(rhs_type)) + and isinstance(c_type, type(output_type)) and m == mnk[0] and n == mnk[1] and k == mnk[2] diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index fe95c52a6..b45771166 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -11,6 +11,7 @@ import re from abc import ABCMeta, abstractmethod +from .op_matchers import * from .common import * @@ -89,6 +90,99 @@ def get_shapes(self, template: list[str]) -> ProblemSize: pass +# TODO(Max191): Support linalg named op versions of contraction ops. The +# current matchers only work for linalg.generic ops. +class ContractionOpInterfaceParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return ( + "matmul_like" in op_name + or "batch_matmul" in op_name + or "batch_matmul_transpose_b" in op_name + or "matmul_transpose_b" in op_name + ) + + def get_contraction_operation( + self, + ir_module: ir.Module, + ) -> Optional[ir.Operation]: + return match_root_op(ir_module, ContractionOpInterfaceMatcher()) + + # TODO(Max191): Pass the ir_module directly instead of the template str. 
+ def get_shapes(self, template: list[str]) -> ProblemSize: + matcher = ContractionOpInterfaceMatcher() + with ir.Context() as ctx: + ir_module = ir.Module.parse("\n".join(template), ctx) + contraction_op = match_root_op(ir_module, matcher) + assert contraction_op is not None, f"contraction op not found" + cdims = matcher.contraction_dimensions + assert cdims, "no contraction dimensions" + assert matcher.lhs_dims, "no lhs dimensions" + assert matcher.rhs_dims, "no rhs dimensions" + assert matcher.res_dims, "no result dimensions" + assert len(cdims.batch) <= 1, f"must have at most 1 batch dimension" + assert len(cdims.m) == 1, f"must have a single m dimension" + assert len(cdims.n) == 1, f"must have a single n dimension" + assert len(cdims.k) == 1, f"must have a single k dimension" + lhs_type = ir.RankedTensorType(contraction_op.operands[0].type) + rhs_type = ir.RankedTensorType(contraction_op.operands[1].type) + res_type = ir.RankedTensorType(contraction_op.operands[2].type) + matmul_size = MatmulSize( + lhs_type.shape[matcher.lhs_dims.index(cdims.m[0])], + rhs_type.shape[matcher.rhs_dims.index(cdims.n[0])], + lhs_type.shape[matcher.lhs_dims.index(cdims.k[0])], + ) + if len(cdims.batch) == 1: + matmul_size.B = lhs_type.shape[matcher.lhs_dims.index(cdims.batch[0])] + return ProblemSize( + matmul_size, + lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), + rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), + res_type=ShapedType(res_type.shape, res_type.element_type), + dispatch_kind=DispatchKind.contraction, + ) + + +# TODO(Max191): Support more convolution types. Only NHWC convs are supported. +class ConvolutionOpInterfaceParser(DispatchParser): + def __init__(self): + self.supported_ops = ["linalg.conv_2d_nhwc_hwcf"] + + def supports(self, op_name: str) -> bool: + for supported_op_name in self.supported_ops: + if supported_op_name.split(".")[-1] in op_name: + return True + return False + + def get_conv_operation( + self, + ir_module: ir.Module, + ) -> Optional[ir.Operation]: + return match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + + # TODO(Max191): Pass the ir_module directly instead of the template str. 
+ def get_shapes(self, template: list[str]) -> ProblemSize: + with ir.Context() as ctx: + ir_module = ir.Module.parse("\n".join(template), ctx) + conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + assert conv_op is not None, f"convolution op not found" + lhs_type = ir.RankedTensorType(conv_op.operands[0].type) + rhs_type = ir.RankedTensorType(conv_op.operands[1].type) + res_type = ir.RankedTensorType(conv_op.operands[2].type) + dim_info = ConvDimInfo.from_rhs_res(rhs_type, res_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), + rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), + res_type=ShapedType(res_type.shape, res_type.element_type), + dispatch_kind=DispatchKind.conv, + ) + + class MmtParser(DispatchParser): def supports(self, op_name: str) -> bool: return "matmul_transpose_b" in op_name diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 9f4afbb19..0b87be659 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -16,6 +16,7 @@ from iree.compiler.dialects import func # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore from iree.compiler.dialects import iree_codegen # type: ignore +from iree.compiler.dialects import linalg # type: ignore from . import common from . import dispatch_parser @@ -40,6 +41,103 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: ) +CONTRACTION_TEMPLATE = r""" +builtin.module{{ + func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{ + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : {res_type} + %1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type} + %2 = linalg.generic {{ + indexing_maps = [ + {lhs_map}, + {rhs_map}, + {res_map}], + iterator_types = {iterator_types}}} + {{root_op}} + ins(%arg0, %arg1 : {lhs_type}, {rhs_type}) + outs(%1 : {res_type}) {{ + ^bb0(%in: f16, %in_0: f16, %out: f32): + %3 = arith.extf %in : f16 to f32 + %4 = arith.extf %in_0 : f16 to f32 + %5 = arith.mulf %3, %4 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + }} -> {res_type} + return %2 : {res_type} + }} +}}""" + + +def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + + with ir.Location.unknown(): + transpose_b_str = CONTRACTION_TEMPLATE.format( + lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()), + rhs_type=ir.RankedTensorType.get([32, 64], ir.F16Type.get()), + res_type=ir.RankedTensorType.get([16, 32], ir.F32Type.get()), + lhs_map="affine_map<(d0, d1, d2) -> (d0, d2)>", + rhs_map="affine_map<(d0, d1, d2) -> (d1, d2)>", + res_map="affine_map<(d0, d1, d2) -> (d0, d1)>", + iterator_types='["parallel", "parallel", "reduction"]', + ) + module = ir.Module.parse(transpose_b_str, context) + parser = dispatch_parser.ContractionOpInterfaceParser() + mmt_op = parser.get_contraction_operation(module) + assert mmt_op is not None + assert isinstance(mmt_op.opview, linalg.GenericOp) + shapes: common.ProblemSize = parser.get_shapes(transpose_b_str.splitlines()) + assert shapes.matmul_size.B == 1 + assert shapes.matmul_size.M == 16 + assert shapes.matmul_size.N == 32 + assert shapes.matmul_size.K == 64 + assert shapes.lhs_type.shape == [16, 64] + assert isinstance(shapes.lhs_type.element_type, ir.F16Type) + assert shapes.rhs_type.shape == [32, 64] + 
assert isinstance(shapes.rhs_type.element_type, ir.F16Type) + assert shapes.res_type.shape == [16, 32] + assert isinstance(shapes.res_type.element_type, ir.F32Type) + + with ir.Location.unknown(): + bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format( + lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()), + rhs_type=ir.RankedTensorType.get([128, 40, 5], ir.F16Type.get()), + res_type=ir.RankedTensorType.get([5, 40, 8], ir.F32Type.get()), + lhs_map="affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>", + rhs_map="affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>", + res_map="affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>", + iterator_types='["parallel", "parallel", "parallel", "reduction"]', + ) + module = ir.Module.parse(bmm_transposed_inputs_str, context) + mmt_op = parser.get_contraction_operation(module) + shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines()) + assert shapes.matmul_size.B == 5 + assert shapes.matmul_size.M == 8 + assert shapes.matmul_size.N == 40 + assert shapes.matmul_size.K == 128 + + +def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2x34x34x16xi8>, %arg1: tensor<3x3x16x16xi8>) -> tensor<2x32x32x16xi32> { + %cst = arith.constant 0 : i32 + %0 = tensor.empty() : tensor<2x32x32x16xi32> + %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> + %2 = linalg.conv_2d_nhwc_hwcf {root_op} + ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>) + outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> + return %2 : tensor<2x32x32x16xi32> + } + }""" + module = ir.Module.parse(module_str, context) + parser = dispatch_parser.ConvolutionOpInterfaceParser() + conv_op = parser.get_conv_operation(module) + assert conv_op is not None + assert isinstance(conv_op.opview, linalg.Conv2DNhwcHwcfOp) + + def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 3c195520c..6bece17f4 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -39,7 +39,10 @@ import json from abc import ABC, abstractmethod import iree.runtime as ireert # type: ignore +from iree.compiler import ir # type: ignore from . import candidate_gen +from . import dispatch_parser +from .common import * # Default values for num_candidates and devices, change it as needed @@ -62,6 +65,7 @@ @dataclass class CandidateTracker: candidate_id: int + mlir_path: Optional[Path] = None dispatch_mlir_path: Optional[Path] = None dispatch_config_path: Optional[Path] = None configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None @@ -746,6 +750,7 @@ def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: file.write("\n") +# TODO(Max191): Remove in favor of using generate_candidate_specs. def generate_candidates( args: argparse.Namespace, path_config: PathConfig, @@ -825,6 +830,66 @@ def generate_candidates( return candidates +def generate_candidate_specs( + args: argparse.Namespace, + path_config: PathConfig, + candidate_trackers: list[CandidateTracker], +) -> list[int]: + """Generate candidate transform dialect specs for tuning. 
Returns the list of candidate indexes""" + logging.debug("generate_candidate_specs()") + + path_config.specs_dir.mkdir(parents=True, exist_ok=True) + tune_logger = logging.getLogger("tune") + + # Generate transform dialect specs. + try: + with open(args.input_file, "r") as f: + mlir_text = f.read() + with ir.Context() as ctx: + tuner_context = TunerContext(ctx, tune_logger) + mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_context) + logging.debug("Captured messages from candidate_gen.py:") + config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( + input_module=mlir_module, + tuner_context=tuner_context, + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + ) + logging.debug("candidate_gen.py ends") + handle_error( + condition=(len(config_specs) <= 1), msg="Failed to generate any candidates" + ) + + # Create candidate trackers. + candidates = [] + for candidate_num, spec in enumerate(config_specs): + candidates.append(candidate_num) + # Move the specs to the canonical path_config location. + spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename( + candidate_num + ) + with open(spec_path, "w") as f: + f.write(str(spec)) + new_candidate = CandidateTracker( + mlir_path=args.input_file, + candidate_id=candidate_num, + spec_path=spec_path, + ) + candidate_trackers.append(new_candidate) + except Exception as e: + logging.error("An error occurred during candidates generation: %s", str(e)) + # Capture and log debug messages from candidate_gen.py. + tune_logger = logging.getLogger("tune_with_td") + for handler in logging.getLogger().handlers: + if isinstance(handler, logging.FileHandler): + tune_logger.handlers.append(handler) + tune_logger.exception("Error in candidate_gen.py:") + raise + + logging.info(f"Generated [{len(candidates) - 1}] candidates") + return candidates + + def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, list[int]]: """If a collision is found, generate a list of new indexes. If no collision, `unique_indexes = []`""" # Check if candidate produces tbe same .vmfb diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py new file mode 100644 index 000000000..1abdafd3d --- /dev/null +++ b/tuner/tuner/op_matchers.py @@ -0,0 +1,178 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This code implements matcher functions for MLIR modules using python bindings. 
+ +from abc import ABCMeta, abstractmethod + +from .common import * +from iree.compiler import ir # type: ignore + + +class OpMatcher(metaclass=ABCMeta): + @abstractmethod + def match(self, op: ir.Operation) -> bool: + """Check if the op passes the matching criteria.""" + pass + + +def walk_collect_ops( + op: ir.Operation, + ops: list[ir.Operation], + fn, +) -> ir.WalkResult: + if fn(op): + ops.append(op) + return ir.WalkResult.ADVANCE + + +def get_ops_from_module(module: ir.Module, fn): + ops: list[ir.Operation] = [] + for op in module.body.operations: + op.walk( + lambda op: walk_collect_ops(op, ops, fn), + ir.WalkOrder.POST_ORDER, + ) + return ops + + +def is_root_op(op: ir.Operation) -> bool: + for attr in op.opview.attributes: + if attr.name == "root_op": + return True + return False + + +def match_root_op( + ir_module: ir.Module, + matcher: OpMatcher, +) -> Optional[ir.Operation]: + root_ops: list[ir.Operation] = get_ops_from_module(ir_module, is_root_op) + if len(root_ops) != 1: + return None + if not matcher.match(root_ops[0].operation): + return None + return root_ops[0] + + +class NamedOpMatcher(OpMatcher): + def __init__(self, op_names: list[str]): + self.op_names = op_names + + def match(self, op: ir.Operation) -> bool: + return op.name in self.op_names + + +# TODO(Max191): Add logic to match the body of the generic op. +class GenericOpMatcher(NamedOpMatcher): + def __init__(self): + super().__init__(["linalg.generic"]) + + @abstractmethod + def match_operands(self, operands: ir.OpOperandList) -> bool: + """Match the operands of the linalg op.""" + pass + + @abstractmethod + def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool: + """Match the indexing_maps of the linalg op.""" + pass + + def match(self, op: ir.Operation) -> bool: + if not super().match(op): + return False + + if not self.match_operands(op.operands): + return False + + maps_attr = None + for attr in op.opview.attributes: + if attr.name == "indexing_maps" and isinstance(attr.attr, ir.ArrayAttr): + maps_attr = attr.attr + if maps_attr is None: + return False + + maps: list[ir.AffineMap] = [] + for map in maps_attr: + maps.append(map.value) + if not self.match_indexing_maps(maps): + return False + + return True + + +def get_map_result_dim_positions(map: ir.AffineMap): + exprs = [] + if not map.is_projected_permutation: + return None + for expr in map.results: + dim_str = str(expr) + if len(dim_str) < 1: + return None + if dim_str[0] != "d": + return None + if not dim_str[1:].isdigit(): + return None + dim_position = int(dim_str[1:]) + exprs.append(dim_position) + return exprs + + +class ContractionOpInterfaceMatcher(GenericOpMatcher): + def __init__(self): + super().__init__() + self.contraction_dimensions: Optional[ContractionDimensions] = None + self.lhs_dims: Optional[list[int]] = None + self.rhs_dims: Optional[list[int]] = None + self.res_dims: Optional[list[int]] = None + + def match_operands(self, operands: ir.OpOperandList) -> bool: + if len(operands) != 3: + return False + for operand in operands: + if not isinstance(operand.type, ir.ShapedType): + return False + return True + + def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool: + if len(maps) != 3: + return False + lhs_dims = get_map_result_dim_positions(maps[0]) + rhs_dims = get_map_result_dim_positions(maps[1]) + res_dims = get_map_result_dim_positions(maps[2]) + if lhs_dims is None or rhs_dims is None or res_dims is None: + return False + + batch_dims = [] + m_dims = [] + n_dims = [] + k_dims = [] + + for d in 
range(maps[0].n_dims): + if d in lhs_dims and d in rhs_dims and d in res_dims: + batch_dims.append(d) + continue + if d in lhs_dims and d in res_dims: + m_dims.append(d) + continue + if d in rhs_dims and d in res_dims: + n_dims.append(d) + continue + if d in lhs_dims and d in rhs_dims: + k_dims.append(d) + continue + return False + + self.contraction_dimensions = ContractionDimensions( + batch=batch_dims, + m=m_dims, + n=n_dims, + k=k_dims, + ) + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.res_dims = res_dims + return True diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py new file mode 100644 index 000000000..a27bd072f --- /dev/null +++ b/tuner/tuner/spec_builder.py @@ -0,0 +1,62 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore + +from .common import * +from .dispatch_constraints import * +from .dispatch_parser import * + + +# TODO(Max191): Use python bindings to build the transform dialect spec module +# instead of using string formatting. +def build_td_spec( + context: ir.Context, + op: ir.Operation, + compilation_info: iree_codegen.CompilationInfoAttr, + func_name: str, +) -> ir.Module: + bbargs = [] + for operand in op.operands: + ssa_name = operand.get_name() + operand_type = operand.type + bbargs.append(f"{ssa_name}: {operand_type}") + bbargs_str = ", ".join(bbargs) + root_operation = str(op) + spec_text = f""" + module attributes {{ transform.with_named_sequence }} {{ + // Annotation Transform + transform.named_sequence @apply_op_config(%op: !transform.any_op {{transform.readonly}}, + %config: !transform.any_param {{transform.readonly}}) {{ + transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param + transform.yield + }} + + // Custom Op Matcher + transform.named_sequence @{func_name}(%cont: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {{ + ^bb0({bbargs_str}): + {root_operation} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant {compilation_info} -> !transform.any_param + transform.yield %cont, %config : !transform.any_op, !transform.any_param + }} + + // Entry Point + transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) {{ + transform.foreach_match in %variant_op + @{func_name} -> @apply_op_config + : (!transform.any_op) -> (!transform.any_op) + transform.yield + }} + }} + """ + return ir.Module.parse(spec_text, context) From ba78824e5cb62fcd48523bc83747311c78806d87 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Mon, 16 Dec 2024 12:04:43 -0800 Subject: [PATCH 17/39] Update llama tests for block size 32 (#696) The block_seq_stride default is changing to 32 instead of 16, so this PR updates the tests to use the block_seq_stride flag and the new numpy inputs for block size 32 to benchmark correctly. This PR also removes the decomposed fp16 tests that are not needed anymore. 
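For reference, a minimal sketch (not taken from this PR; the IRPA path and output names are placeholders) of how the new `block_seq_stride` argument reaches the exporter, which forwards it as `--block-seq-stride`:

    from sharktank.utils.export_artifacts import ExportArtifacts

    artifacts = ExportArtifacts(
        irpa_path="/data/llama3.1/8b/llama8b_f16.irpa",  # placeholder path
        batch_size=4,
        iree_hip_target="gfx942",
        iree_hal_target_backends="rocm",
        attention_kernel="torch",
        tensor_parallelism_size=1,
        block_seq_stride=32,  # exporter receives --block-seq-stride=32
    )
    artifacts.export_to_mlir(
        mlir_path="llama8b_f16_torch.mlir",  # placeholder output names
        json_path="llama8b_f16_torch.json",
    )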
--------- Signed-off-by: aviator19941 --- sharktank/sharktank/utils/export_artifacts.py | 5 +- .../models/llama/benchmark_amdgpu_test.py | 215 +++--------------- 2 files changed, 32 insertions(+), 188 deletions(-) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index ec75d597e..0bf252525 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -92,6 +92,7 @@ def __init__( iree_hal_target_backends: str, attention_kernel: str, tensor_parallelism_size: int, + block_seq_stride: Optional[int] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -102,6 +103,7 @@ def __init__( self.iree_hal_target_backends = iree_hal_target_backends self.attention_kernel = attention_kernel self.tensor_parallelism_size = tensor_parallelism_size + self.block_seq_stride = block_seq_stride def timeit(func): def wrapper(*args, **kwargs): @@ -184,6 +186,8 @@ def export_to_mlir( if self.attention_kernel in ["decomposed", "torch"]: export_args.append("--attention-kernel") export_args.append(self.attention_kernel) + if self.block_seq_stride: + export_args.append(f"--block-seq-stride={self.block_seq_stride}") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) @@ -280,7 +284,6 @@ def iree_benchmark_vmfb( benchmark_args += [ "iree-benchmark-module", "--hip_use_streams=true", - "--hip_allow_inline_execution=true", "--device_allocator=caching", f"--module={vmfb_name}", ] diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 13a2c35e4..0c45bdffa 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -74,14 +74,6 @@ def setUp(self): self.dir_path_8b = self.dir_path / "llama-8b" self.temp_dir_8b = Path(self.dir_path_8b) self.temp_dir_8b.mkdir(parents=True, exist_ok=True) - self.llama8b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) self.llama8b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -89,6 +81,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) self.llama8b_fp8_decomposed_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -97,6 +90,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -105,48 +99,42 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) - self.prefill_args_f16 = self.artifacts_dir / "prefill_args" - self.prefill_args_bs4_128_in_tokens_f16 = ( - self.artifacts_dir / "prefill_args_bs4_128" + self.prefill_args_bs4_128_in_tokens_stride_32_f16 = ( + self.artifacts_dir / "prefill_args_bs4_128_stride_32" ) self.prefill_args_bs4_2048_in_tokens_f16 = ( self.artifacts_dir / "prefill_args_bs4_2048" ) - self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.decode_args_bs4_128_in_tokens_stride_32_f16 = ( + 
self.artifacts_dir / "decode_args_bs4_128_stride_32" + ) self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" - self.iree_run_prefill_args = [ - "--function=prefill_bs4", - f"--input=@{self.prefill_args_f16}/tokens.npy", - f"--input=@{self.prefill_args_f16}/seq_lens.npy", - f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", - f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", - "--benchmark_repetitions=3", - ] self.iree_run_prefill_nondecomposed_args_fp16 = [ "--function=prefill_bs4", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/tokens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] self.iree_run_prefill_nondecomposed_args_fp16_2048 = [ "--function=prefill_bs4", - f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens_2048.npy", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_lens.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_block_ids.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] - self.iree_run_decode_args = [ + self.iree_run_decode_nondecomposed_args_f16 = [ "--function=decode_bs4", - f"--input=@{self.decode_args_f16}/tokens.npy", - f"--input=@{self.decode_args_f16}/seq_lens.npy", - f"--input=@{self.decode_args_f16}/start_positions.npy", - f"--input=@{self.decode_args_f16}/seq_block_ids.npy", - f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/next_tokens.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/seq_lens.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/start_positions.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] self.iree_run_prefill_args_fp8 = [ @@ -167,46 +155,6 @@ def setUp(self): "--benchmark_repetitions=3", ] - def testBenchmark8B_f16_Decomposed(self): - output_file_name = self.dir_path_8b / "f16_decomposed" - output_mlir = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - export_return_code = self.llama8b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama8b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - 
args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - - @skipif_run_quick_llama_test def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self): output_file_name = self.dir_path_8b / "f16_torch_prefill_128" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -218,7 +166,6 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -252,7 +199,7 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + self.llama8b_f16_torch_sdpa_artifacts.block_seq_stride = 16 export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -275,7 +222,6 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark8B_f16_Non_Decomposed(self): output_file_name = self.dir_path_8b / "f16_torch" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -287,7 +233,6 @@ def testBenchmark8B_f16_Non_Decomposed(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -304,7 +249,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): hip_device_id=self.iree_device, vmfb_name=output_vmfb, irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, + args=self.iree_run_prefill_nondecomposed_args_fp16, cwd=self.repo_root, ) # benchmark decode @@ -312,7 +257,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): hip_device_id=self.iree_device, vmfb_name=output_vmfb, irpa_path=self.irpa_path, - args=self.iree_run_decode_args, + args=self.iree_run_decode_nondecomposed_args_f16, cwd=self.repo_root, ) @@ -410,14 +355,6 @@ def setUp(self): self.dir_path_70b = self.dir_path / "llama-70b" self.temp_dir_70b = Path(self.dir_path_70b) self.temp_dir_70b.mkdir(parents=True, exist_ok=True) - self.llama70b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) self.llama70b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -425,6 +362,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama70b_fp8_decomposed_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -433,6 +371,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", 
tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama70b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -441,6 +380,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.prefill_args_f16 = self.artifacts_dir / "prefill_args" self.prefill_args_bs4_128_in_tokens_f16 = ( @@ -495,52 +435,6 @@ def setUp(self): @pytest.mark.xfail( reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException ) - def testBenchmark70B_f16_TP8_Decomposed(self): - output_file_name = self.dir_path_70b / "f16_decomposed" - output_mlir = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - output_shard_file_name = ( - self.artifacts_dir - / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" - ) - if output_shard_file_name.exists(): - self.irpa_path = output_shard_file_name - export_return_code = self.llama70b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama70b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark70B_f16_TP8_Non_Decomposed(self): output_file_name = self.dir_path_70b / "f16_torch" output_mlir = self.llama70b_f16_torch_sdpa_artifacts.create_file( @@ -697,14 +591,6 @@ def setUp(self): self.dir_path_405b = self.dir_path / "llama-405b" self.temp_dir_405b = Path(self.dir_path_405b) self.temp_dir_405b.mkdir(parents=True, exist_ok=True) - self.llama405b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) self.llama405b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -712,6 +598,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama405b_fp8_decomposed_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -720,6 +607,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama405b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -728,6 +616,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + 
block_seq_stride=16, ) self.prefill_args_f16 = self.artifacts_dir / "prefill_args" self.prefill_args_bs4_128_in_tokens_f16 = ( @@ -779,54 +668,6 @@ def setUp(self): "--benchmark_repetitions=3", ] - @pytest.mark.xfail( - reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException - ) - def testBenchmark405B_f16_TP8_Decomposed(self): - output_file_name = self.dir_path_405b / "f16_decomposed" - output_mlir = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - output_shard_file_name = ( - self.artifacts_dir - / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" - ) - if output_shard_file_name.exists(): - self.irpa_path = output_shard_file_name - export_return_code = self.llama405b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama405b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - @pytest.mark.xfail( reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException ) From cfadf2aa9930686a0425c60a0a3e2d99a2a0d6c8 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 16 Dec 2024 16:46:45 -0500 Subject: [PATCH 18/39] Add a asyncio.Lock to debug dumper to prevent overwrites when multiple debug dump requests are simultaneously issued (#705) --- .../llm/components/service_debug_dumper.py | 344 +++++++++--------- 1 file changed, 174 insertions(+), 170 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py index ae492eab3..90842c44d 100644 --- a/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py +++ b/shortfin/python/shortfin_apps/llm/components/service_debug_dumper.py @@ -27,6 +27,7 @@ def __init__(self): self.debug_data_dir / "llm_service_invocation_dumps" / self.boot_timestamp ) self.dump_dir.mkdir(parents=True, exist_ok=False) + self._dump_lock = asyncio.Lock() logger.info( f"[debug_service.py] Please find debug dumps for service.py in {self.dump_dir}" ) @@ -35,181 +36,184 @@ async def pre_invocation_debug_dump( self, executor: "InferenceExecutorProcess", local_vars: Dict[str, Any] ): """Comprehensive debug dump before inference invocation.""" - # Extract variables from locals - is_decode = local_vars["is_decode"] - device0 = local_vars["device0"] - fn = local_vars["fn"] - req_bs = local_vars["req_bs"] - bsl = local_vars["bsl"] - seq_stride = local_vars["seq_stride"] - block_count = local_vars["block_count"] - req_count = local_vars["req_count"] - tokens = local_vars["tokens"] - start_positions = local_vars.get("start_positions") - seq_lens = local_vars["seq_lens"] - seq_block_ids = 
local_vars["seq_block_ids"] - args = local_vars["args"] - - phase = executor.phase - exec_requests = executor.exec_requests - model_params = executor.service.model_params - - dump_path = self.dump_dir / f"{self.dump_id}" - dump_path.mkdir(parents=True, exist_ok=True) - - # Prepare debug info dictionary - debug_info = { - "metadata": { - "dump_id": self.dump_id, - "dump_timestamp": datetime.now().isoformat(), - "phase": str(phase), - "is_decode": is_decode, - "device": str(device0), - "function": str(fn), - }, - "batch_info": { - "request_batch_size": req_bs, - "block_sequence_length": int(bsl), - "sequence_stride": seq_stride, - "block_count": block_count, - "actual_request_count": req_count, - }, - "requests": [ - { - "index": i, - "start_position": req.start_position, - "rid": req.rid, - "input_token_ids": req.input_token_ids.tolist() - if hasattr(req.input_token_ids, "tolist") - else list(req.input_token_ids), - "input_length": len(req.input_token_ids), - "cache_pages": req.cache_page_indices(block_count), - } - for i, req in enumerate(exec_requests) - ], - "tensor_shapes": { - "tokens": tokens.shape, - **({"start_positions": start_positions.shape} if is_decode else {}), - "seq_lens": seq_lens.shape, - "seq_block_ids": seq_block_ids.shape, - }, - "tensor_values": { - "tokens": tokens.for_transfer().items.tolist() - if hasattr(tokens.for_transfer().items, "tolist") - else list(tokens.for_transfer().items), - **( + async with self._dump_lock: + # Extract variables from locals + is_decode = local_vars["is_decode"] + device0 = local_vars["device0"] + fn = local_vars["fn"] + req_bs = local_vars["req_bs"] + bsl = local_vars["bsl"] + seq_stride = local_vars["seq_stride"] + block_count = local_vars["block_count"] + req_count = local_vars["req_count"] + tokens = local_vars["tokens"] + start_positions = local_vars.get("start_positions") + seq_lens = local_vars["seq_lens"] + seq_block_ids = local_vars["seq_block_ids"] + args = local_vars["args"] + + phase = executor.phase + exec_requests = executor.exec_requests + model_params = executor.service.model_params + + dump_path = self.dump_dir / f"{self.dump_id}" + dump_path.mkdir(parents=True, exist_ok=True) + + # Prepare debug info dictionary + debug_info = { + "metadata": { + "dump_id": self.dump_id, + "dump_timestamp": datetime.now().isoformat(), + "phase": str(phase), + "is_decode": is_decode, + "device": str(device0), + "function": str(fn), + }, + "batch_info": { + "request_batch_size": req_bs, + "block_sequence_length": int(bsl), + "sequence_stride": seq_stride, + "block_count": block_count, + "actual_request_count": req_count, + }, + "requests": [ { - "start_positions": start_positions.for_transfer().items.tolist() - if hasattr(start_positions.for_transfer().items, "tolist") - else list(start_positions.for_transfer().items) + "index": i, + "start_position": req.start_position, + "rid": req.rid, + "input_token_ids": req.input_token_ids.tolist() + if hasattr(req.input_token_ids, "tolist") + else list(req.input_token_ids), + "input_length": len(req.input_token_ids), + "cache_pages": req.cache_page_indices(block_count), } - if is_decode - else {} - ), - "sequence_lengths": seq_lens.for_transfer().items.tolist() - if hasattr(seq_lens.for_transfer().items, "tolist") - else list(seq_lens.for_transfer().items), - "sequence_block_ids": seq_block_ids.for_transfer().items.tolist() - if hasattr(seq_block_ids.for_transfer().items, "tolist") - else list(seq_block_ids.for_transfer().items), - }, - "model_config": { - "prefill_batch_sizes": 
model_params.prefill_batch_sizes, - "decode_batch_sizes": model_params.decode_batch_sizes, - "attn_dtype": str(model_params.attn_dtype), - "paged_kv_cache": { - "device_block_count": model_params.paged_kv_cache.device_block_count, - "block_seq_stride": model_params.paged_kv_cache.block_seq_stride, - "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm, + for i, req in enumerate(exec_requests) + ], + "tensor_shapes": { + "tokens": tokens.shape, + **({"start_positions": start_positions.shape} if is_decode else {}), + "seq_lens": seq_lens.shape, + "seq_block_ids": seq_block_ids.shape, + }, + "tensor_values": { + "tokens": tokens.for_transfer().items.tolist() + if hasattr(tokens.for_transfer().items, "tolist") + else list(tokens.for_transfer().items), + **( + { + "start_positions": start_positions.for_transfer().items.tolist() + if hasattr(start_positions.for_transfer().items, "tolist") + else list(start_positions.for_transfer().items) + } + if is_decode + else {} + ), + "sequence_lengths": seq_lens.for_transfer().items.tolist() + if hasattr(seq_lens.for_transfer().items, "tolist") + else list(seq_lens.for_transfer().items), + "sequence_block_ids": seq_block_ids.for_transfer().items.tolist() + if hasattr(seq_block_ids.for_transfer().items, "tolist") + else list(seq_block_ids.for_transfer().items), + }, + "model_config": { + "prefill_batch_sizes": model_params.prefill_batch_sizes, + "decode_batch_sizes": model_params.decode_batch_sizes, + "attn_dtype": str(model_params.attn_dtype), + "paged_kv_cache": { + "device_block_count": model_params.paged_kv_cache.device_block_count, + "block_seq_stride": model_params.paged_kv_cache.block_seq_stride, + "prefix_sharing_algorithm": model_params.paged_kv_cache.prefix_sharing_algorithm, + }, }, - }, - } - - # Save debug info as JSON - with open(dump_path / "info.json", "w") as f: - json.dump(debug_info, f, indent=2) - - # Save program arguments - path = dump_path - args_np = [] - for i, a in enumerate(args): - host_array = a.for_transfer() - host_array.copy_from(a) - await a.device - args_np.append(np.array(host_array)) - - # Save binary numpy arrays - for i, arr in enumerate(args_np): - np.save(path / f"{i}.npy", arr) - - # Generate human-readable report - with open(path / "saved_program_args.txt", "w") as f: + } + + # Save debug info as JSON + with open(dump_path / "info.json", "w") as f: + json.dump(debug_info, f, indent=2) + + # Save program arguments + path = dump_path + args_np = [] + for i, a in enumerate(args): + host_array = a.for_transfer() + host_array.copy_from(a) + await a.device + args_np.append(np.array(host_array)) + + # Save binary numpy arrays for i, arr in enumerate(args_np): - f.write(f"\n{'='*80}\n") - f.write(f"{i}.npy:\n") - f.write(f"{'='*80}\n\n") - - # Basic info - f.write(f"Shape: {arr.shape}\n") - f.write(f"Dtype: {arr.dtype}\n") - f.write(f"Total elements: {arr.size}\n") - f.write(f"Dimensions: {arr.ndim}\n\n") - - # Stats - f.write("Statistics:\n") - nan_count = np.count_nonzero(np.isnan(arr)) - inf_count = np.count_nonzero(np.isinf(arr)) - f.write(f"- NaN count: {nan_count}\n") - f.write(f"- Inf count: {inf_count}\n") - - if nan_count == 0 and inf_count == 0: - f.write(f"- Min: {np.min(arr)}\n") - f.write(f"- Max: {np.max(arr)}\n") - f.write(f"- Mean: {np.mean(arr):.6f}\n") - f.write(f"- Median: {np.median(arr):.6f}\n") - f.write(f"- Range: {np.ptp(arr)}\n") - try: - mode = pd.Series(arr.flatten()).mode().iloc[0] - f.write(f"- Mode: {mode}\n") - except: - f.write("- Mode: Unable to compute\n") - - if 
np.issubdtype(arr.dtype, np.number): + np.save(path / f"{i}.npy", arr) + + # Generate human-readable report + with open(path / "saved_program_args.txt", "w") as f: + for i, arr in enumerate(args_np): + f.write(f"\n{'='*80}\n") + f.write(f"{i}.npy:\n") + f.write(f"{'='*80}\n\n") + + # Basic info + f.write(f"Shape: {arr.shape}\n") + f.write(f"Dtype: {arr.dtype}\n") + f.write(f"Total elements: {arr.size}\n") + f.write(f"Dimensions: {arr.ndim}\n\n") + + # Stats + f.write("Statistics:\n") + nan_count = np.count_nonzero(np.isnan(arr)) + inf_count = np.count_nonzero(np.isinf(arr)) + f.write(f"- NaN count: {nan_count}\n") + f.write(f"- Inf count: {inf_count}\n") + + if nan_count == 0 and inf_count == 0: + f.write(f"- Min: {np.min(arr)}\n") + f.write(f"- Max: {np.max(arr)}\n") + f.write(f"- Mean: {np.mean(arr):.6f}\n") + f.write(f"- Median: {np.median(arr):.6f}\n") + f.write(f"- Range: {np.ptp(arr)}\n") try: - hist, bins = np.histogram(arr.flatten(), bins="auto") - f.write("\nHistogram:\n") - f.write( - "Bins: " - + pformat(bins.tolist(), width=80, compact=True) - + "\n" - ) - f.write( - "Counts: " - + pformat(hist.tolist(), width=80, compact=True) - + "\n" - ) - except Exception as e: - f.write(f"\nHistogram computation failed: {str(e)}\n") - else: - f.write("Skipping additional statistics due to NaN/Inf values\n") - - f.write("\nArray contents:\n") - if arr.size <= 64: - formatted = pformat(arr.tolist(), width=80, compact=True) - f.write(formatted + "\n") - else: - f.write("\nFirst 5 elements:\n") - f.write( - pformat(arr.flatten()[:5].tolist(), width=80, compact=True) - + "\n" - ) - f.write("\nLast 5 elements:\n") - f.write( - pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) - + "\n" - ) - - self.dump_id += 1 + mode = pd.Series(arr.flatten()).mode().iloc[0] + f.write(f"- Mode: {mode}\n") + except: + f.write("- Mode: Unable to compute\n") + + if np.issubdtype(arr.dtype, np.number): + try: + hist, bins = np.histogram(arr.flatten(), bins="auto") + f.write("\nHistogram:\n") + f.write( + "Bins: " + + pformat(bins.tolist(), width=80, compact=True) + + "\n" + ) + f.write( + "Counts: " + + pformat(hist.tolist(), width=80, compact=True) + + "\n" + ) + except Exception as e: + f.write(f"\nHistogram computation failed: {str(e)}\n") + else: + f.write( + "Skipping additional statistics due to NaN/Inf values\n" + ) + + f.write("\nArray contents:\n") + if arr.size <= 64: + formatted = pformat(arr.tolist(), width=80, compact=True) + f.write(formatted + "\n") + else: + f.write("\nFirst 5 elements:\n") + f.write( + pformat(arr.flatten()[:5].tolist(), width=80, compact=True) + + "\n" + ) + f.write("\nLast 5 elements:\n") + f.write( + pformat(arr.flatten()[-5:].tolist(), width=80, compact=True) + + "\n" + ) + + self.dump_id += 1 # Create single instance From 0660b07de5e8534df73d07913c986def97696d0d Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Mon, 16 Dec 2024 14:35:42 -0800 Subject: [PATCH 19/39] Add unit test summary jobs to sharktank and shortfin workflows. (#695) Progress on https://github.com/nod-ai/shark-ai/issues/357. This gives us a single job in each unit test workflow to set as a "required check" that will block merging, replacing the current single `Unit Tests (3.11, 2.3.0, ubuntu-24.04)` check: ![image](https://github.com/user-attachments/assets/f8cc1101-606a-44ac-b467-f2d80fc45357) I'm also switching ci-libshortfin.yml to run on all changes, which is a requirement for making a check required. We could still conditionally skip jobs based on paths modified... 
just not using GitHub's built in path filtering. --- .github/workflows/ci-libshortfin.yml | 25 +++++++++++++++++++------ .github/workflows/ci-sharktank.yml | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-libshortfin.yml b/.github/workflows/ci-libshortfin.yml index 1bc1c913f..0e0982803 100644 --- a/.github/workflows/ci-libshortfin.yml +++ b/.github/workflows/ci-libshortfin.yml @@ -9,15 +9,9 @@ name: CI - shortfin on: workflow_dispatch: pull_request: - paths: - - '.github/workflows/ci-libshortfin.yml' - - 'shortfin/**' push: branches: - main - paths: - - '.github/workflows/ci-libshortfin.yml' - - 'shortfin/**' permissions: contents: read @@ -152,3 +146,22 @@ jobs: run: | ctest --timeout 30 --output-on-failure --test-dir build pytest -s --durations=10 + + # Depends on all other jobs to provide an aggregate job status. + ci_libshortfin_summary: + if: always() + runs-on: ubuntu-24.04 + needs: + - build-and-test + steps: + - name: Getting failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 8e41c68b8..f3d47595c 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -191,3 +191,23 @@ jobs: run: | pytest -v sharktank/ -m punet_quick \ --durations=0 + + # Depends on other jobs to provide an aggregate job status. + # TODO(#584): move test_with_data and test_integration to a pkgci integration test workflow? + ci_sharktank_summary: + if: always() + runs-on: ubuntu-24.04 + needs: + - test + steps: + - name: Getting failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi From d1980c7e842b0c6d4ddc2e730b0e6654a9ce59ac Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 16 Dec 2024 23:56:13 +0100 Subject: [PATCH 20/39] [sharktank] Remove 'torch' from deps and warn instead (#706) Instead of enforcing the installation for 'torch' as a dependency, error if 'torch' cannot be imported and point the user to how to install. Co-authored-by: Scott Todd --- docs/user_guide.md | 13 ++++++++++--- sharktank/requirements.txt | 4 ---- sharktank/sharktank/__init__.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/docs/user_guide.md b/docs/user_guide.md index a0415eb63..d3ef192e0 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -34,13 +34,20 @@ Setup your Python environment with the following commands: # Set up a virtual environment to isolate packages from other envs. python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate +``` + +## Install SHARK and its dependencies + +First install a torch version that fulfills your needs: -# Optional: faster installation of torch with just CPU support. -# See other options at https://pytorch.org/get-started/locally/ +```bash +# Fast installation of torch with just CPU support. 
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ``` -## Install SHARK and its dependencies +For other options, see https://pytorch.org/get-started/locally/. + +Next install shark-ai: ```bash pip install shark-ai[apps] diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 8a2a8ea3b..70780c346 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -9,10 +9,6 @@ huggingface-hub==0.22.2 transformers==4.40.0 datasets -# It is expected that you have installed a PyTorch version/variant specific -# to your needs, so we only include a minimum version spec. -torch>=2.3.0 - # Serving deps. fastapi>=0.112.2 uvicorn>=0.30.6 diff --git a/sharktank/sharktank/__init__.py b/sharktank/sharktank/__init__.py index a85ba359d..c0eb89810 100644 --- a/sharktank/sharktank/__init__.py +++ b/sharktank/sharktank/__init__.py @@ -3,3 +3,13 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib.util + +msg = """No module named 'torch'. Follow https://pytorch.org/get-started/locally/#start-locally to install 'torch'. +For example, on Linux to install with CPU support run: + pip3 install torch --index-url https://download.pytorch.org/whl/cpu +""" + +if spec := importlib.util.find_spec("torch") is None: + raise ModuleNotFoundError(msg) From b151ffa89f1eaa068af0bd7eedea3c3b36b22320 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 17 Dec 2024 00:50:20 +0100 Subject: [PATCH 21/39] [shark-ai] Include sharktank in the meta package (#701) --- .../python_deploy/write_requirements.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py index 224e01fd0..5cdea0ff3 100755 --- a/build_tools/python_deploy/write_requirements.py +++ b/build_tools/python_deploy/write_requirements.py @@ -53,8 +53,8 @@ def write_requirements(requirements): metapackage_version = load_version_info(VERSION_FILE_LOCAL) PACKAGE_VERSION = metapackage_version.get("package-version") -# sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) -# SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") +sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) +SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") @@ -65,13 +65,12 @@ def write_requirements(requirements): requirements = "" for package in stable_packages_list: requirements += package + "\n" - # TODO: Include sharktank as a dependencies of future releases - # requirements = ( - # "sharktank==" - # + Version(SHARKTANK_PACKAGE_VERSION).base_version - # + args.version_suffix - # + "\n" - # ) + requirements = ( + "sharktank==" + + Version(SHARKTANK_PACKAGE_VERSION).base_version + + args.version_suffix + + "\n" + ) requirements += ( "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version @@ -89,10 +88,9 @@ def write_requirements(requirements): requirements = "" for package in stable_packages_list: requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n" - # TODO: Include sharktank as a dependencies of future releases - # requirements += ( - # "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" - # ) + requirements += ( + "sharktank==" + 
Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" + ) requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version write_requirements(requirements) From 83437b5ca890c85fb6ea81760833b5025f9c2553 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 17 Dec 2024 06:33:33 -0800 Subject: [PATCH 22/39] Add Flux transformer export for easier use outside of tests (#700) Adapt the model to accept parameters as structured in the HF repo. Make Punet parameters importation from HF more general to serve other models as well. When downloading a dataset from Hugging Face make it return the local location of all downloaded files including extras, not just the "leading" file. Add sample_inputs method to the BaseLayer interface to help standardize exportation. Introduce a standard export function for static-sized models. --- .github/workflows/ci-sharktank.yml | 10 +- sharktank/conftest.py | 9 ++ .../models/punet/integration_test.py | 2 +- sharktank/sharktank/export.py | 53 +++++++- sharktank/sharktank/layers/base.py | 24 +++- sharktank/sharktank/layers/mmdit.py | 25 ++-- sharktank/sharktank/layers/testing.py | 16 +-- sharktank/sharktank/models/flux/export.py | 49 +++++++ sharktank/sharktank/models/flux/flux.py | 122 ++++++++++++++---- .../punet => }/tools/import_hf_dataset.py | 56 +++++--- sharktank/sharktank/utils/cli.py | 14 +- sharktank/sharktank/utils/hf_datasets.py | 80 ++++++++++-- sharktank/tests/models/flux/flux_test.py | 91 +++++++------ sharktank/tests/models/llama/prefill_tests.py | 4 +- 14 files changed, 417 insertions(+), 138 deletions(-) create mode 100644 sharktank/sharktank/models/flux/export.py rename sharktank/sharktank/{models/punet => }/tools/import_hf_dataset.py (61%) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index f3d47595c..d46efef09 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -136,15 +136,19 @@ jobs: pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ - name: Run tests - # TODO: unify with-t5-data and with-clip-data flags into a single flag - # and make it possible to run only tests that require data. + # TODO: unify with-*-data flags into a single flag and make it possible to run + # only tests that require data. + # We would still want the separate flags as we may endup with data being + # scattered on different CI machines. run: | source ${VENV_DIR}/bin/activate pytest \ - --with-clip-data \ + --with-clip-data \ + --with-flux-data \ --with-t5-data \ sharktank/tests/models/clip/clip_test.py \ sharktank/tests/models/t5/t5_test.py \ + sharktank/tests/models/flux/flux_test.py \ --durations=0 diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 9d6257513..d7118893a 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -97,6 +97,15 @@ def pytest_addoption(parser): "code. The user is expected to provide the data" ), ) + parser.addoption( + "--with-flux-data", + action="store_true", + default=False, + help=( + "Enable tests that use Flux data like models that is not a part of the source " + "code. 
The user is expected to provide the data" + ), + ) parser.addoption( "--with-t5-data", action="store_true", diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py index 754a54311..2ebb9e155 100644 --- a/sharktank/integration/models/punet/integration_test.py +++ b/sharktank/integration/models/punet/integration_test.py @@ -67,7 +67,7 @@ def download(filename): @pytest.fixture(scope="module") def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir): - from sharktank.models.punet.tools import import_hf_dataset + from sharktank.tools import import_hf_dataset dataset = temp_dir / "sdxl_fp16_dataset.irpa" import_hf_dataset.main( diff --git a/sharktank/sharktank/export.py b/sharktank/sharktank/export.py index 0a1c6940d..b54978e8b 100644 --- a/sharktank/sharktank/export.py +++ b/sharktank/sharktank/export.py @@ -4,11 +4,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Callable, Any +from typing import Callable, Optional, Any import torch +from os import PathLike +import iree.turbine.aot as aot from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten from .types.tensors import ShardedTensor +from .layers import BaseLayer from torch.utils._pytree import PyTree, _is_leaf import functools @@ -172,3 +175,51 @@ def flat_fn(*args, **kwargs): ) assert False, "TODO: implement the case when not using an FxProgramsBuilder" + + +def export_static_model_mlir( + model: BaseLayer, + output_path: PathLike, + function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None, + batch_sizes: Optional[list[int]] = None, +): + """Export a model with no dynamic dimensions. + + For the set of provided function name batch sizes pair, the resulting MLIR will + have function names with the below format. + ``` + _bs + ``` + + If `batch_sizes` is given then it defaults to a single function with named + "forward". + + The model is required to implement method `sample_inputs`. + """ + + assert not (function_batch_size_pairs is not None and batch_sizes is not None) + + if batch_sizes is not None: + function_batch_size_pairs = {None: batch_sizes} + + if function_batch_size_pairs is None and batch_sizes is None: + function_batch_size_pairs = {None: batch_sizes} + + fxb = FxProgramsBuilder(model) + + for function, batch_sizes in function_batch_size_pairs.items(): + for batch_size in batch_sizes: + args, kwargs = model.sample_inputs(batch_size, function) + + @fxb.export_program( + name=f"{function or 'forward'}_bs{batch_size}", + args=args, + kwargs=kwargs, + dynamic_shapes=None, + strict=False, + ) + def _(model, **kwargs): + return model(**kwargs) + + output = aot.export(fxb) + output.save_mlir(output_path) diff --git a/sharktank/sharktank/layers/base.py b/sharktank/sharktank/layers/base.py index 11a21f885..8f74c239d 100644 --- a/sharktank/sharktank/layers/base.py +++ b/sharktank/sharktank/layers/base.py @@ -4,15 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict - +from typing import Dict, Optional +from collections import OrderedDict import torch import torch.nn as nn -from ..types import ( - InferenceTensor, - Theta, -) +from ..types import InferenceTensor, Theta, AnyTensor from ..utils import debugging __all__ = [ @@ -56,6 +53,21 @@ def assert_not_nan(self, *ts: torch.Tensor): if torch.isnan(t).any(): raise AssertionError(f"Tensor contains nans! {t}") + def sample_inputs( + self, batch_size: int = 1, function: Optional[str] = None + ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]: + """Return sample inputs that can be used to run the function from the model. + If function is None then layer is treated as the callable. + E.g. + ``` + args, kwargs = model.sample_inputs() + model(*args, **kwargs) + ``` + + One purpose of this method is to standardize exportation of models to MLIR. + """ + raise NotImplementedError() + class ThetaLayer(BaseLayer): "Base class for layers that derive parameters from a Theta object." diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py index 1557883ae..1c398f608 100644 --- a/sharktank/sharktank/layers/mmdit.py +++ b/sharktank/sharktank/layers/mmdit.py @@ -55,11 +55,15 @@ def __init__(self, theta, num_heads: int): self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv"))) self.add_module( "img_attn_norm_q", - RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6), + RMSNormLayer( + theta("img_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module( "img_attn_norm_k", - RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6), + RMSNormLayer( + theta("img_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj"))) @@ -70,11 +74,15 @@ def __init__(self, theta, num_heads: int): self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))) self.add_module( "txt_attn_norm_q", - RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6), + RMSNormLayer( + theta("txt_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module( "txt_attn_norm_k", - RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6), + RMSNormLayer( + theta("txt_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6 + ), ) self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj"))) @@ -151,14 +159,15 @@ def __init__(self, theta, num_heads: int): super().__init__(theta) self.num_heads = num_heads - self.add_module("mod", ModulationLayer(theta("mod"), double=False)) + self.add_module("mod", ModulationLayer(theta("modulation"), double=False)) self.add_module( - "attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6) + "attn_norm_q", + RMSNormLayer(theta("norm.query_norm"), weight_name="scale", epsilon=1e-6), ) self.add_module( - "attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6) + "attn_norm_k", + RMSNormLayer(theta("norm.key_norm"), weight_name="scale", epsilon=1e-6), ) - self.add_module("attn_proj", LinearLayer(theta("attn.proj"))) self.add_module("linear1", LinearLayer(theta("linear1"))) self.add_module("linear2", LinearLayer(theta("linear2"))) diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py index 74ba49624..6ea089bb7 100644 --- a/sharktank/sharktank/layers/testing.py +++ b/sharktank/sharktank/layers/testing.py @@ -65,10 +65,10 @@ def make_mmdit_double_block_random_theta( mlp_hidden_size3 = 
int(2 * (mlp_ratio - 1) * hidden_size) return Theta( { - "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "img_attn.proj.bias": DefaultPrimitiveTensor( @@ -101,10 +101,10 @@ def make_mmdit_double_block_random_theta( "img_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "txt_attn.proj.bias": DefaultPrimitiveTensor( @@ -155,10 +155,10 @@ def make_mmdit_single_block_random_theta( mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size) return Theta( { - "attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), - "attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels,), dtype=dtype) ), "attn.proj.bias": DefaultPrimitiveTensor( @@ -179,10 +179,10 @@ def make_mmdit_single_block_random_theta( "linear2.weight": DefaultPrimitiveTensor( data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) ), - "mod.lin.bias": DefaultPrimitiveTensor( + "modulation.lin.bias": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), - "mod.lin.weight": DefaultPrimitiveTensor( + "modulation.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), } diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py new file mode 100644 index 000000000..fae3a5362 --- /dev/null +++ b/sharktank/sharktank/models/flux/export.py @@ -0,0 +1,49 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from os import PathLike + +from ...export import export_static_model_mlir +from ...tools.import_hf_dataset import import_hf_dataset +from .flux import FluxModelV1, FluxParams +from ...types import Dataset +from ...utils.hf_datasets import get_dataset + +flux_transformer_default_batch_sizes = [4] + + +def export_flux_transformer_model_mlir( + model: FluxModelV1, + output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes) + + +def export_flux_transformer_from_hugging_face( + repo_id: str, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + hf_dataset = get_dataset( + repo_id, + ).download() + + import_hf_dataset( + config_json_path=hf_dataset["config"][0], + param_paths=hf_dataset["parameters"], + output_irpa_file=parameters_output_path, + ) + + dataset = Dataset.load(parameters_output_path) + model = FluxModelV1( + theta=dataset.root_theta, + params=FluxParams.from_hugging_face_properties(dataset.properties), + ) + export_flux_transformer_model_mlir( + model, output_path=mlir_output_path, batch_sizes=batch_sizes + ) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index ac63f47a0..d99b14ad4 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -9,6 +9,8 @@ https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py """ +from typing import Any, Optional +from collections import OrderedDict import math from dataclasses import dataclass import torch @@ -45,6 +47,46 @@ class FluxParams: qkv_bias: bool guidance_embed: bool + @staticmethod + def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams": + p = properties["hparams"] + + in_channels = p["in_channels"] + out_channels = p["in_channels"] + vec_in_dim = p["pooled_projection_dim"] + context_in_dim = p["joint_attention_dim"] + mlp_ratio = 4.0 + hidden_size = vec_in_dim * int(mlp_ratio) + num_heads = p["num_attention_heads"] + depth = p["num_layers"] + depth_single_blocks = p["num_single_layers"] + + # TODO: figure out relation between hidden_size, num_heads and + # attention_head_dim. + # diffusers.FluxTransformer2DModel also hardcodes this. 
+ axes_dim = [16, 56, 56] + assert sum(axes_dim) == p["attention_head_dim"] + + theta = 10_000 + qkv_bias = True + guidance_embed = p["guidance_embeds"] + + return FluxParams( + in_channels=in_channels, + out_channels=out_channels, + vec_in_dim=vec_in_dim, + context_in_dim=context_in_dim, + mlp_ratio=mlp_ratio, + hidden_size=hidden_size, + num_heads=num_heads, + depth=depth, + depth_single_blocks=depth_single_blocks, + axes_dim=axes_dim, + theta=theta, + qkv_bias=qkv_bias, + guidance_embed=guidance_embed, + ) + class FluxModelV1(ThetaLayer): """FluxModel adapted from Black Forest Lab's implementation.""" @@ -71,16 +113,12 @@ def __init__(self, theta: Theta, params: FluxParams): dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim ) self.add_module("img_in", LinearLayer(theta("img_in"))) - # TODO: Refactor this pattern to an MLPEmbedder like src implementatio - self.add_module("time_in_0", LinearLayer(theta("time_in.0"))) - self.add_module("time_in_1", LinearLayer(theta("time_in.1"))) - self.add_module("vector_in_0", LinearLayer(theta("vector_in.0"))) - self.add_module("vector_in_1", LinearLayer(theta("vector_in.1"))) + self.add_module("time_in", MLPEmbedder(theta("time_in"))) + self.add_module("vector_in", MLPEmbedder(theta("vector_in"))) self.guidance = False if params.guidance_embed: self.guidance = True - self.add_module("guidance_in_0", LinearLayer(theta("guidance_in.0"))) - self.add_module("guidance_in_1", LinearLayer(theta("guidance_in.1"))) + self.add_module("guidance_in", MLPEmbedder(theta("guidance_in"))) self.add_module("txt_in", LinearLayer(theta("txt_in"))) self.double_blocks = nn.ModuleList( @@ -104,8 +142,8 @@ def __init__(self, theta: Theta, params: FluxParams): ) self.add_module( - "last_layer", - LastLayer(theta("last_layer")), + "final_layer", + LastLayer(theta("final_layer")), ) def forward( @@ -123,23 +161,14 @@ def forward( # running on sequences img img = self.img_in(img) - time_in_0 = self.time_in_0(timestep_embedding(timesteps, 256)) - time_in_silu = ops.elementwise(F.silu, time_in_0) - vec = self.time_in_1(time_in_silu) + vec = self.time_in(timestep_embedding(timesteps, 256)) if self.guidance: if guidance is None: raise ValueError( "Didn't get guidance strength for guidance distilled model." ) - guidance_inp = timestep_embedding(guidance, 256) - guidance0 = self.guidance_in0(guidance_inp) - guidance_silu = ops.elementwise(F.silu, guidance0) - guidance_out = self.guidance_in1(guidance_silu) - vec = vec + self.guidance_in(guidance_out) - vector_in_0 = self.vector_in_0(y) - vector_in_silu = ops.elementwise(F.silu, vector_in_0) - vector_in_1 = self.vector_in_1(vector_in_silu) - vec = vec + vector_in_1 + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) txt = self.txt_in(txt) @@ -154,9 +183,36 @@ def forward( img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] - img = self.last_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + def sample_inputs( + self, batch_size: int = 1, function: Optional[str] = None + ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]: + if not (function is None or function == "forward"): + raise ValueError(f'Only function "forward" is supported. Got "{function}"') + + # TODO: do not hardcode these but derive the required shapes from the config. 
+ img = torch.rand([batch_size, 1024, 64]) + img_ids = torch.rand([batch_size, 1024, 3]) + txt = torch.rand([batch_size, 512, 4096]) + txt_ids = torch.rand([batch_size, 512, 3]) + timesteps = torch.rand([batch_size]) + y = torch.rand([batch_size, 768]) + + args = tuple() + kwargs = OrderedDict( + ( + ("img", img), + ("img_ids", img_ids), + ("txt", txt), + ("txt_ids", txt_ids), + ("timesteps", timesteps), + ("y", y), + ) + ) + return args, kwargs + ################################################################################ # Layers @@ -216,6 +272,18 @@ def rope(pos: AnyTensor, dim: int, theta: int) -> AnyTensor: return out.float() +class MLPEmbedder(ThetaLayer): + def __init__(self, theta: Theta): + super().__init__(theta) + self.in_layer = LinearLayer(theta("in_layer")) + self.out_layer = LinearLayer(theta("out_layer")) + + def forward(self, x: AnyTensor) -> AnyTensor: + x = self.in_layer(x) + x = ops.elementwise(torch.nn.functional.silu, x) + return self.out_layer(x) + + class EmbedND(torch.nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -239,13 +307,15 @@ def __init__( theta: Theta, ): super().__init__(theta) - self.add_module("outlinear", LinearLayer(theta("outlinear"))) - self.add_module("ada_linear", LinearLayer(theta("ada_linear"))) + self.add_module( + "adaLN_modulation_linear", LinearLayer(theta("adaLN_modulation.1")) + ) + self.add_module("linear", LinearLayer(theta("linear"))) def forward(self, x: AnyTensor, vec: AnyTensor) -> AnyTensor: silu = ops.elementwise(F.silu, vec) - lin = self.ada_linear(silu) + lin = self.adaLN_modulation_linear(silu) shift, scale = lin.chunk(2, dim=1) x = (1 + scale[:, None, :]) * layer_norm(x) + shift[:, None, :] - x = self.outlinear(x) + x = self.linear(x) return x diff --git a/sharktank/sharktank/models/punet/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py similarity index 61% rename from sharktank/sharktank/models/punet/tools/import_hf_dataset.py rename to sharktank/sharktank/tools/import_hf_dataset.py index 0afa5222d..8b8feed9f 100644 --- a/sharktank/sharktank/models/punet/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -14,21 +14,31 @@ Usage: python -m sharktank.models.punet.import_hf_dataset \ --output-irpa-file ~/models/punet/punet_fp16.irpa \ - --config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json + --config-json ~/models/stable-diffusion-xl-base-1.0/unet/config.json \ + --params diffusion_pytorch_model.fp16.safetensors The resulting dataset has all tensors as nested in the original model. Properties are separated into a "meta" dict (for "_" prefixed props) and an "hparams" dict. 
""" +from typing import Optional +from os import PathLike import json from pathlib import Path import sys +import logging -from ....types import * +from ..types import * +logger = logging.getLogger(__name__) -def import_hf_config(config_json_path: Path, params_path: Path) -> Dataset: + +def import_hf_dataset( + config_json_path: PathLike, + param_paths: list[PathLike], + output_irpa_file: Optional[PathLike] = None, +) -> Optional[Dataset]: import safetensors with open(config_json_path, "rb") as f: @@ -37,22 +47,28 @@ def import_hf_config(config_json_path: Path, params_path: Path) -> Dataset: meta_params = {k: v for k, v in config_json.items() if k.startswith("_")} hparams = {k: v for k, v in config_json.items() if not k.startswith("_")} - with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: - tensors = [ - DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) - for name in st.keys() - ] + for params_path in param_paths: + with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: + tensors = [ + DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) + for name in st.keys() + ] theta = Theta(tensors) props = { "meta": meta_params, "hparams": hparams, } - return Dataset(props, theta) + dataset = Dataset(props, theta) + + if output_irpa_file is None: + return dataset + + dataset.save(output_irpa_file, io_report_callback=logger.info) -def main(argv): - from ....utils import cli +def main(argv: list[str]): + from ..utils import cli parser = cli.create_parser() cli.add_output_dataset_options(parser) @@ -62,18 +78,22 @@ def main(argv): parser.add_argument( "--params", type=Path, + nargs="+", default=Path("diffusion_pytorch_model.fp16.safetensors"), - help="Parameter file name, relative to config.json", + help="Parameter file name(s), relative to config.json", ) args = cli.parse(parser, args=argv) config_json_path: Path = args.config_json - params_path: Path = args.params - if not params_path.is_absolute(): - params_path = config_json_path.parent / params_path - - dataset = import_hf_config(config_json_path, params_path) - dataset.save(args.output_irpa_file, io_report_callback=print) + param_paths: list[Path] = args.params + param_paths = [ + path if path.is_absolute() else config_json_path.parent / path + for path in param_paths + ] + + import_hf_dataset( + config_json_path, param_paths, output_irpa_file=args.output_irpa_file + ) if __name__ == "__main__": diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 9fefeb66f..b4b405dca 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -100,7 +100,7 @@ def add_tokenizer_options(parser: argparse.ArgumentParser): ) -def get_input_data_files(args) -> Optional[dict[str, Path]]: +def get_input_data_files(args) -> Optional[dict[str, list[Path]]]: """Gets data files given the input arguments. 
Keys may contain: @@ -112,9 +112,9 @@ def get_input_data_files(args) -> Optional[dict[str, Path]]: dataset = hf_datasets.get_dataset(args.hf_dataset).download() return dataset elif args.gguf_file is not None: - return {"gguf": args.gguf_file} + return {"gguf": [args.gguf_file]} elif args.irpa_file is not None: - return {"irpa": args.irpa_file} + return {"irpa": [args.irpa_file]} def get_input_dataset(args) -> Dataset: @@ -124,10 +124,10 @@ def get_input_dataset(args) -> Dataset: """ data_files = get_input_data_files(args) if "gguf" in data_files: - return Dataset.load(data_files["gguf"], file_type="gguf") + return Dataset.load(data_files["gguf"][0], file_type="gguf") if "irpa" in data_files: - return Dataset.load(data_files["irpa"], file_type="irpa") + return Dataset.load(data_files["irpa"][0], file_type="irpa") raise ValueError(f'Dataset format unsupported. Must be "gguf" or "irpa".') @@ -142,7 +142,7 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer: return tokenizer.fake_tokenizer() if args.tokenizer_config_json is not None: - data_files = {"tokenizer_config.json": args.tokenizer_config_json} + data_files = {"tokenizer_config.json": [args.tokenizer_config_json]} else: data_files = get_input_data_files(args) @@ -150,7 +150,7 @@ def get_tokenizer(args) -> tokenizer.InferenceTokenizer: if tokenizer_type is None: if "tokenizer_config.json" in data_files: return tokenizer.load_tokenizer( - data_files["tokenizer_config.json"].parent, + data_files["tokenizer_config.json"][0].parent, tokenizer_type="transformers", ) else: diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index c6a799404..6893b637a 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -33,16 +33,26 @@ class RemoteFile: filename: str extra_filenames: Sequence[str] = () - def download(self, *, local_dir: Optional[Path] = None) -> Path: - for extra_filename in self.extra_filenames: - hf_hub_download( - repo_id=self.repo_id, filename=extra_filename, local_dir=local_dir - ) - return Path( - hf_hub_download( - repo_id=self.repo_id, filename=self.filename, local_dir=local_dir + def download(self, *, local_dir: Optional[Path] = None) -> list[Path]: + res = [] + res.append( + Path( + hf_hub_download( + repo_id=self.repo_id, filename=self.filename, local_dir=local_dir + ) ) ) + for extra_filename in self.extra_filenames: + res.append( + Path( + hf_hub_download( + repo_id=self.repo_id, + filename=extra_filename, + local_dir=local_dir, + ) + ) + ) + return res @dataclass @@ -59,7 +69,7 @@ def alias_to(self, to_name: str) -> "Dataset": alias_dataset(self.name, to_name) return self - def download(self, *, local_dir: Optional[Path] = None) -> Dict[str, Path]: + def download(self, *, local_dir: Optional[Path] = None) -> Dict[str, list[Path]]: return {f.file_id: f.download(local_dir=local_dir) for f in self.files} @@ -363,6 +373,54 @@ def alias_dataset(from_name: str, to_name: str): ), ) +# The Flux transformer is in 2 formats. 
+# This is used in diffusers.FluxTransformer2DModel +Dataset( + "black-forest-labs/FLUX.1-schnell/transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-schnell", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-schnell", + "transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + extra_filenames=[ + "transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + ], + ), + RemoteFile( + "parameters-index", + "black-forest-labs/FLUX.1-schnell", + "transformer/diffusion_pytorch_model.safetensors.index.json", + ), + ), +) + +# The Flux transformer is in 2 formats. +# This is used in the Black Forest's Flux repo. +# https://github.com/black-forest-labs/flux +# We have based our implementation on that. +Dataset( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-schnell", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-schnell", + "flux1-schnell.safetensors", + ), + ), +) + + ################################################################################ # Tool entrypoint ################################################################################ @@ -386,8 +444,8 @@ def main(): for dataset_name in args.dataset_name: print(f"Downloading dataset {dataset_name}") ds = get_dataset(dataset_name).download(local_dir=args.local_dir) - for key, path in ds.items(): - print(f" {key}: {path}") + for key, paths in ds.items(): + print(f" {key}: {paths}") if __name__ == "__main__": diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index ea80c7b42..fc4d23251 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -5,18 +5,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import logging - -logging.basicConfig(level=logging.DEBUG) - import unittest - import torch - -from iree.turbine import aot +import pytest from sharktank.models.flux.flux import ( FluxModelV1, FluxParams, ) +from sharktank.models.flux.export import ( + export_flux_transformer_model_mlir, + export_flux_transformer_from_hugging_face, +) import sharktank.ops as ops from sharktank.layers.testing import ( make_rand_torch, @@ -24,11 +23,14 @@ from sharktank.types.tensors import DefaultPrimitiveTensor from sharktank.types.theta import Dataset, Theta from sharktank.utils.testing import TempDirTestBase +from sharktank.utils.hf_datasets import get_dataset + +logging.basicConfig(level=logging.DEBUG) +with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") # TODO: Refactor this to a function that generates random toy weights, possibly # to another file -dtype = torch.float32 in_channels = 64 in_channels2 = 128 hidden_size = 3072 @@ -45,7 +47,7 @@ out_channels = 64 -def make_random_theta(): +def make_random_theta(dtype: torch.dtype): return Theta( { "img_in.weight": DefaultPrimitiveTensor( # @@ -60,34 +62,34 @@ def make_random_theta(): "txt_in.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "time_in.0.weight": DefaultPrimitiveTensor( # + "time_in.in_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, time_dim), dtype=dtype) ), - "time_in.0.bias": DefaultPrimitiveTensor( # + "time_in.in_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "time_in.1.weight": 
DefaultPrimitiveTensor( # + "time_in.out_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), - "time_in.1.bias": DefaultPrimitiveTensor( # + "time_in.out_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "vector_in.0.weight": DefaultPrimitiveTensor( # + "vector_in.in_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) ), - "vector_in.0.bias": DefaultPrimitiveTensor( # + "vector_in.in_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "vector_in.1.weight": DefaultPrimitiveTensor( # + "vector_in.out_layer.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) ), - "vector_in.1.bias": DefaultPrimitiveTensor( # + "vector_in.out_layer.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size,), dtype=dtype) ), - "double_blocks.0.img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "double_blocks.0.img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( @@ -120,10 +122,10 @@ def make_random_theta(): "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "double_blocks.0.txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "double_blocks.0.txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( @@ -156,10 +158,10 @@ def make_random_theta(): "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) ), - "single_blocks.0.attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + "single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), - "single_blocks.0.attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + "single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( # data=make_rand_torch((in_channels2,), dtype=dtype) ), "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( @@ -180,26 +182,26 @@ def make_random_theta(): "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) ), - "single_blocks.0.mod.lin.bias": DefaultPrimitiveTensor( + "single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size,), dtype=dtype) ), - "single_blocks.0.mod.lin.weight": DefaultPrimitiveTensor( + "single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor( data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) ), - "last_layer.outlinear.weight": DefaultPrimitiveTensor( # + "final_layer.linear.weight": DefaultPrimitiveTensor( # data=make_rand_torch( (patch_size * patch_size * out_channels, hidden_size), dtype=dtype ) ), - "last_layer.outlinear.bias": DefaultPrimitiveTensor( # + 
"final_layer.linear.bias": DefaultPrimitiveTensor( # data=make_rand_torch( (patch_size * patch_size * out_channels,), dtype=dtype ) ), - "last_layer.ada_linear.weight": DefaultPrimitiveTensor( # + "final_layer.adaLN_modulation.1.weight": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) ), - "last_layer.ada_linear.bias": DefaultPrimitiveTensor( # + "final_layer.adaLN_modulation.1.bias": DefaultPrimitiveTensor( # data=make_rand_torch((hidden_size * 2,), dtype=dtype) ), } @@ -214,7 +216,8 @@ def setUp(self): self.num_heads = 24 self.batch_size = 5 - def testExport(self): + def testExportBfloat16SingleLayer(self): + dtype = torch.bfloat16 params = FluxParams( in_channels=64, out_channels=64, @@ -230,32 +233,26 @@ def testExport(self): qkv_bias=True, guidance_embed=False, ) - theta = make_random_theta() + theta = make_random_theta(dtype) theta = self.save_load_theta(theta) flux = FluxModelV1( theta=theta, params=params, ) - img = torch.rand([self.batch_size, 1024, 64]) - img_ids = torch.rand([self.batch_size, 1024, 3]) - txt = torch.rand([self.batch_size, 512, 4096]) - txt_ids = torch.rand([self.batch_size, 512, 3]) - timesteps = torch.rand([self.batch_size]) - y = torch.rand([self.batch_size, 768]) - - flux.forward(img, img_ids, txt, txt_ids, timesteps, y) - fxb = aot.FxProgramsBuilder(flux) - - @fxb.export_program( - name="flux", args=(img, img_ids, txt, txt_ids, timesteps, y), strict=False + export_flux_transformer_model_mlir( + flux, + output_path=self._temp_dir / "model.mlir", + batch_sizes=[self.batch_size], ) - def _(model, img, img_ids, txt, txt_ids, timesteps, y) -> torch.Tensor: - return model.forward(img, img_ids, txt, txt_ids, timesteps, y) - output = aot.export(fxb) - output.verify() - asm = str(output.mlir_module) + @with_flux_data + def testExportSchnellFromHuggingFace(self): + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", + ) def save_load_theta(self, theta: Theta): # Roundtrip to disk to avoid treating parameters as constants that would appear diff --git a/sharktank/tests/models/llama/prefill_tests.py b/sharktank/tests/models/llama/prefill_tests.py index 093ecdfc9..f7b456389 100644 --- a/sharktank/tests/models/llama/prefill_tests.py +++ b/sharktank/tests/models/llama/prefill_tests.py @@ -86,7 +86,7 @@ def setUp(self): self.data_files = hf_datasets.get_dataset( default_arguments["hf_dataset"] ).download(local_dir=Path(".")) - self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.dataset = Dataset.load(self.data_files["gguf"][0], file_type="gguf") self.tokenizer_config = tokenizer.load_tokenizer( default_arguments["tokenizer-config-json"].parent, tokenizer_type="transformers", @@ -138,7 +138,7 @@ def setUp(self): self.data_files = hf_datasets.get_dataset( default_arguments["hf_dataset"] ).download(local_dir=Path(".")) - self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.dataset = Dataset.load(self.data_files["gguf"][0], file_type="gguf") self.tokenizer_config = tokenizer.load_tokenizer( default_arguments["tokenizer-config-json"].parent, tokenizer_type="transformers", From d08b0e5dc4ecf4af364121933fbe0aa61031cc66 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 17 Dec 2024 09:39:10 -0800 Subject: [PATCH 23/39] [shortfin] Use custom manylinux dockerfile in build_linux_package.sh. 
(#709) Package builds started failing last night when the `latest` upstream manylinux dockerfile switched to gcc 14: https://github.com/nod-ai/shark-ai/actions/runs/12371699664/job/34528374484 ``` Running command Building wheel for shortfin (pyproject.toml) -- The C compiler identification is GNU 14.2.1 -- The CXX compiler identification is GNU 14.2.1 ... [325/365] Building CXX object src/shortfin/local/CMakeFiles/shortfin_local.dylib.objects.dir/device.cc.o FAILED: src/shortfin/local/CMakeFiles/shortfin_local.dylib.objects.dir/device.cc.o /opt/rh/gcc-toolset-14/root/usr/bin/c++ -DCPUINFO_SUPPORTED_PLATFORM=1 -DSPDLOG_COMPILED_LIB -DSPDLOG_FMT_EXTERNAL -DSPDLOG_SHARED_LIB -D_SHORTFIN_BUILDING_DYLIB -Dspdlog_EXPORTS -I/home/runner/work/shark-ai/shark-ai/c/shortfin/src -I/home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/src -I/home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/spdlog-src/include -I/home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/fmt-src/include -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-build -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src/runtime/src -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-build/runtime/src -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src/third_party/cpuinfo/include -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-build/runtime/src/iree/base/internal/flatcc -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src/third_party/flatcc/include -isystem /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-build/runtime/src/iree/schemas -O3 -DNDEBUG -std=gnu++20 -flto=auto -fno-fat-lto-objects -fPIC -fvisibility=hidden -fvisibility-inlines-hidden -Wall -Werror -pthread -I/home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src/third_party/flatcc/include/ -I/home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/shortfin_iree-src/third_party/flatcc/include/flatcc/reflection/ -MD -MT src/shortfin/local/CMakeFiles/shortfin_local.dylib.objects.dir/device.cc.o -MF src/shortfin/local/CMakeFiles/shortfin_local.dylib.objects.dir/device.cc.o.d -o src/shortfin/local/CMakeFiles/shortfin_local.dylib.objects.dir/device.cc.o -c /home/runner/work/shark-ai/shark-ai/c/shortfin/src/shortfin/local/device.cc In file included from /home/runner/work/shark-ai/shark-ai/c/shortfin/src/shortfin/local/device.cc:10: /home/runner/work/shark-ai/shark-ai/c/shortfin/build/cmake/default/_deps/fmt-src/include/fmt/ranges.h:211:49: error: self-comparison always evaluates to true [-Werror=tautological-compare] 211 | integer_sequence) -> std::true_type; | ~~~^~~~~ cc1plus: all warnings being treated as errors ``` This switches to our own downstream dockerfile, defined here: https://github.com/nod-ai/base-docker-images/blob/main/dockerfiles/manylinux_x86_64.Dockerfile, which is pinned to an older version of the base image (and thus gcc). 
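The same pinning can also be applied per invocation without editing the script, since `MANYLINUX_DOCKER_IMAGE` is read from the environment when set; a minimal sketch (the digest is a placeholder, not a tested pin):

```shell
# Hypothetical override: point the build script at a pinned manylinux image
# by digest instead of the floating `latest` tag. Substitute a digest that
# predates the gcc 14 toolchain bump.
MANYLINUX_DOCKER_IMAGE="quay.io/pypa/manylinux_2_28_x86_64@sha256:<known-good-digest>" \
  ./shortfin/build_tools/build_linux_package.sh
```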
Tested successfully here:
https://github.com/ScottTodd/shark-ai/actions/runs/12378199850
---
 shortfin/build_tools/build_linux_package.sh | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/shortfin/build_tools/build_linux_package.sh b/shortfin/build_tools/build_linux_package.sh
index 91b944e51..db2463987 100755
--- a/shortfin/build_tools/build_linux_package.sh
+++ b/shortfin/build_tools/build_linux_package.sh
@@ -37,13 +37,18 @@ REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
 SCRIPT_NAME="$(basename $0)"
 ARCH="$(uname -m)"

-# Note: we can switch to https://github.com/nod-ai/base-docker-images as needed for extra deps.
-MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}"
 PYTHON_VERSIONS="${OVERRIDE_PYTHON_VERSIONS:-cp311-cp311 cp312-cp312 cp313-cp313}"
 OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
 CACHE_DIR="${CACHE_DIR:-}"
 SHORTFIN_ENABLE_TRACING="${SHORTFIN_ENABLE_TRACING:-ON}"

+if [[ "${ARCH}" == "x86_64" ]]; then
+  MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-ghcr.io/nod-ai/manylinux_x86_64@sha256:4acf83343706d1e37252d6001ded3c97a73bc38620580f855b4e65e35ddc5681}"
+else
+  # TODO: publish a multi-platform manylinux image and include more deps in all platforms (rust, ccache, etc.)
+  MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}"
+fi
+
 function run_on_host() {
   echo "Running on host"
   echo "Launching docker image ${MANYLINUX_DOCKER_IMAGE}"

From aab71618d6d0926720d68d7b9f2c5810fc1e6f86 Mon Sep 17 00:00:00 2001
From: Marius Brehler
Date: Tue, 17 Dec 2024 21:03:26 +0100
Subject: [PATCH 24/39] [shortfin] Bump fmt and spdlog and test with GCC 14
 (#711)

Bumps libfmt to 11.0.2 to mitigate a build error occurring with GCC 14.
Bumping spdlog to 1.15.0 (which bundles libfmt 11.0.2) accordingly to keep
the libs in sync. Furthermore expands testing to build with GCC 14.
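The CI change below selects the compilers through the standard CMake cache variables, so the new GCC 14 configuration can also be reproduced locally. A rough sketch, assuming `gcc-14`/`g++-14` are installed and all other options are left at their defaults:

```shell
# Hypothetical local configure/build mirroring the "Ubuntu (GCC 14)" CI entry.
cmake -S shortfin -B shortfin/build-gcc14 \
  -DCMAKE_C_COMPILER=gcc-14 -DCMAKE_CXX_COMPILER=g++-14
cmake --build shortfin/build-gcc14
```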
--- .github/workflows/ci-libshortfin.yml | 19 ++++++++++++------- shortfin/CMakeLists.txt | 8 ++++++-- shortfin/src/shortfin/array/array.cc | 1 + shortfin/src/shortfin/local/device.cc | 1 + shortfin/src/shortfin/local/fiber.cc | 1 + shortfin/src/shortfin/local/program.cc | 1 + shortfin/src/shortfin/local/system.cc | 1 + shortfin/src/shortfin/local/systems/amdgpu.cc | 2 ++ .../src/shortfin/local/systems/factory.cc | 2 ++ shortfin/src/shortfin/support/config.cc | 1 + 10 files changed, 28 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci-libshortfin.yml b/.github/workflows/ci-libshortfin.yml index 0e0982803..543a6abe6 100644 --- a/.github/workflows/ci-libshortfin.yml +++ b/.github/workflows/ci-libshortfin.yml @@ -38,7 +38,7 @@ jobs: strategy: fail-fast: false matrix: - name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Ubuntu (GCC)", "Windows (MSVC)"] + name: ["Ubuntu (Clang)(full)", "Ubuntu (Clang)(host-only)", "Windows (MSVC)"] python-version: ["3.10", "3.11", "3.12"] include: - name: Ubuntu (Clang)(full) @@ -53,16 +53,21 @@ jobs: cmake-options: -DCMAKE_C_COMPILER=clang-18 -DCMAKE_CXX_COMPILER=clang++-18 -DCMAKE_LINKER_TYPE=LLD -DSHORTFIN_HAVE_AMDGPU=OFF -DSHORTFIN_BUILD_STATIC=ON -DSHORTFIN_BUILD_DYNAMIC=ON additional-packages: clang lld - - name: Ubuntu (GCC) + - name: Ubuntu (GCC 13) runs-on: ubuntu-24.04 + # Only test with GCC 13 and Python 3.12 + python-version: "3.12" + cmake-options: + -DCMAKE_C_COMPILER=gcc-13 -DCMAKE_CXX_COMPILER=g++-13 + - name: Ubuntu (GCC 14) + runs-on: ubuntu-24.04 + # Only test with GCC 14 and Python 3.12 + python-version: "3.12" + cmake-options: + -DCMAKE_C_COMPILER=gcc-14 -DCMAKE_CXX_COMPILER=g++-14 - name: Windows (MSVC) runs-on: windows-2022 exclude: - # Only test Python 3.12 with GCC - - name: Ubuntu (GCC) - python-version: "3.10" - - name: Ubuntu (GCC) - python-version: "3.11" # TODO: Include additional Python versions for Windows after build got fixed - name: Windows (MSVC) python-version: "3.10" diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 2c79d5b41..bd46d84f9 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -39,6 +39,10 @@ if(NOT WIN32) set(CMAKE_POSITION_INDEPENDENT_CODE ON) endif() +# For unicode support Windows libfmt requires compiling with /utf-8. 
+add_compile_options("$<$:/utf-8>") +add_compile_options("$<$:/utf-8>") + # Pins set(SHORTFIN_IREE_GIT_TAG "iree-3.1.0rc20241204") @@ -140,7 +144,7 @@ if(SHORTFIN_BUNDLE_DEPS) FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git - GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog) + GIT_TAG 0c9fce2ffefecfdce794e1859584e25877b7b592 # 11.0.2 (sync with spdlog) ) ## spdlog @@ -149,7 +153,7 @@ if(SHORTFIN_BUNDLE_DEPS) FetchContent_Declare( spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1 + GIT_TAG 8e5613379f5140fefb0b60412fbf1f5406e7c7f8 # v1.15.0 ) ## xtl: required for xtensor diff --git a/shortfin/src/shortfin/array/array.cc b/shortfin/src/shortfin/array/array.cc index 882e4ef39..c0eca52d0 100644 --- a/shortfin/src/shortfin/array/array.cc +++ b/shortfin/src/shortfin/array/array.cc @@ -10,6 +10,7 @@ #include "fmt/core.h" #include "fmt/ranges.h" +#include "fmt/xchar.h" #include "shortfin/array/xtensor_bridge.h" #include "shortfin/support/logging.h" diff --git a/shortfin/src/shortfin/local/device.cc b/shortfin/src/shortfin/local/device.cc index 3afd2b8ad..1bed3a419 100644 --- a/shortfin/src/shortfin/local/device.cc +++ b/shortfin/src/shortfin/local/device.cc @@ -8,6 +8,7 @@ #include #include +#include namespace shortfin::local { diff --git a/shortfin/src/shortfin/local/fiber.cc b/shortfin/src/shortfin/local/fiber.cc index 8ad9f2960..2c03672fd 100644 --- a/shortfin/src/shortfin/local/fiber.cc +++ b/shortfin/src/shortfin/local/fiber.cc @@ -8,6 +8,7 @@ #include #include +#include #include "shortfin/local/system.h" #include "shortfin/support/logging.h" diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 6ab1f47ae..71452da3e 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -8,6 +8,7 @@ #include "fmt/core.h" #include "fmt/std.h" +#include "fmt/xchar.h" #include "iree/io/formats/parser_registry.h" #include "iree/modules/hal/module.h" #include "iree/modules/io/parameters/module.h" diff --git a/shortfin/src/shortfin/local/system.cc b/shortfin/src/shortfin/local/system.cc index ef31bb001..00fcf4c65 100644 --- a/shortfin/src/shortfin/local/system.cc +++ b/shortfin/src/shortfin/local/system.cc @@ -7,6 +7,7 @@ #include "shortfin/local/system.h" #include +#include #include "iree/hal/utils/allocators.h" #include "shortfin/local/fiber.h" diff --git a/shortfin/src/shortfin/local/systems/amdgpu.cc b/shortfin/src/shortfin/local/systems/amdgpu.cc index cecedd1a0..262d2ec62 100644 --- a/shortfin/src/shortfin/local/systems/amdgpu.cc +++ b/shortfin/src/shortfin/local/systems/amdgpu.cc @@ -6,6 +6,8 @@ #include "shortfin/local/systems/amdgpu.h" +#include + #include "shortfin/support/logging.h" #include "shortfin/support/sysconfig.h" diff --git a/shortfin/src/shortfin/local/systems/factory.cc b/shortfin/src/shortfin/local/systems/factory.cc index bf5b788dc..c5ee036cd 100644 --- a/shortfin/src/shortfin/local/systems/factory.cc +++ b/shortfin/src/shortfin/local/systems/factory.cc @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include + #include "shortfin/local/system.h" #include "shortfin/support/logging.h" diff --git a/shortfin/src/shortfin/support/config.cc b/shortfin/src/shortfin/support/config.cc index 7de820d1c..d188ddb16 100644 --- a/shortfin/src/shortfin/support/config.cc +++ b/shortfin/src/shortfin/support/config.cc @@ -12,6 +12,7 @@ #include #include "fmt/format.h" +#include "fmt/xchar.h" #include "shortfin/support/logging.h" namespace shortfin { From c4a592ac8bcb2202a554ab1a4d311fdf5ddf28eb Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:35:11 -0800 Subject: [PATCH 25/39] [sharktank] Update block_seq_stride for perplexity CI tests (#707) - Update `block_seq_stride` for perplexity CI tests - Update default value of `block_seq_stride` from `16` to `32` in `export_paged_llm_v1.py` --- .github/workflows/ci_eval.yaml | 4 +- .github/workflows/ci_eval_short.yaml | 2 +- app_tests/integration_tests/llm/utils.py | 1 + .../sharktank/evaluate/perplexity_iree.py | 47 +++++++++++-------- .../sharktank/examples/export_paged_llm_v1.py | 2 +- .../sharktank/layers/configs/llm_configs.py | 2 +- sharktank/sharktank/utils/export_artifacts.py | 5 +- 7 files changed, 36 insertions(+), 27 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 3b85cb652..a71698774 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -24,7 +24,7 @@ jobs: test_perplexity_iree: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Perplexity-IREE" + name: "IREE Perplexity" strategy: matrix: version: [3.11] @@ -83,7 +83,7 @@ jobs: test_perplexity_torch: if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} timeout-minutes: 1000 - name: "Perplexity-Torch" + name: "Torch Perplexity" strategy: matrix: version: [3.11] diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml index edaaee966..d5f8f5682 100644 --- a/.github/workflows/ci_eval_short.yaml +++ b/.github/workflows/ci_eval_short.yaml @@ -23,7 +23,7 @@ concurrency: jobs: test_perplexity_iree: - name: "Llama3.1 8B FP16" + name: "IREE Perplexity" strategy: matrix: version: [3.11] diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py index 80b5b3c09..dbbdee10d 100644 --- a/app_tests/integration_tests/llm/utils.py +++ b/app_tests/integration_tests/llm/utils.py @@ -90,6 +90,7 @@ def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes): "python", "-m", "sharktank.examples.export_paged_llm_v1", + "--block-seq-stride=16", f"--{model_path.suffix.strip('.')}-file={model_path}", f"--output-mlir={mlir_path}", f"--output-config={config_path}", diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py index 6060eb91b..c47726f0e 100644 --- a/sharktank/sharktank/evaluate/perplexity_iree.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -68,12 +68,14 @@ def __init__( kv_cache_type, tensor_parallelism_size, attention_kernel, + block_seq_stride, ): self.torch_device = torch_device self.iree_device = iree_device self.iree_hip_target = iree_hip_target self.iree_hal_target_backends = iree_hal_target_backends self.kv_cache_type = kv_cache_type + self.block_seq_stride = block_seq_stride self.activation_dtype = torch.float16 self.attention_dtype = torch.float16 
self.tensor_parallelism_size = tensor_parallelism_size @@ -136,6 +138,7 @@ def compile_model(self, weight_path_str): iree_hal_target_backends=self.iree_hal_target_backends, attention_kernel=self.attention_kernel, tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=self.block_seq_stride, ) vmfb_path = export_artifacts.get_artifacts() return vmfb_path @@ -145,7 +148,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), - block_seq_stride=16, + block_seq_stride=self.block_seq_stride, kv_cache_type=self.kv_cache_type, device=self.torch_device, activation_dtype=self.activation_dtype, @@ -394,6 +397,7 @@ def run_perplexity( tensor_parallelism_size, attention_kernel, num_prompts, + block_seq_stride, ): start = time.time() perplexity = Perplexity( @@ -404,6 +408,7 @@ def run_perplexity( kv_cache_type=kv_cache_type, tensor_parallelism_size=tensor_parallelism_size, attention_kernel=attention_kernel, + block_seq_stride=block_seq_stride, ) perplexity.get_prompts(num_prompts=num_prompts) @@ -425,8 +430,18 @@ def run_perplexity( def main(argv): parser = cli.create_parser() - parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") - parser.add_argument("--torch-device", help="Torch device (or default)") + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + parser.add_argument( + "--block-seq-stride", + help="Block sequence stride for paged KV cache, must divide evenly into the context length", + type=int, + default=32, + ) parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") parser.add_argument( "--iree-hip-target", @@ -440,11 +455,12 @@ def main(argv): default="rocm", help="Specify the iree-hal target backends (e.g., rocm)", ) + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") parser.add_argument( - "--attention-kernel", - type=str, - default="decomposed", - choices=["decomposed", "torch_sdpa"], + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test (1 to 100)", ) parser.add_argument( "--tensor-parallelism-size", @@ -452,36 +468,29 @@ def main(argv): default=1, help="Number of devices for tensor parallel sharding", ) - parser.add_argument( - "--num-prompts", - type=int, - default=100, - help="Number of prompts for perplexity test", - ) + parser.add_argument("--torch-device", help="Torch device (or default)") cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) args = cli.parse(parser, args=argv) torch_device = torch.device(args.torch_device) if args.torch_device else None - iree_device = args.iree_device - kv_cache_type = args.kv_cache_type weight_path = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) - weight_path_str = str(args.irpa_file) ppl = run_perplexity( weight_path=weight_path, - weight_path_str=weight_path_str, + weight_path_str=str(args.irpa_file), tokenizer=tokenizer, torch_device=torch_device, - iree_device=iree_device, + iree_device=args.iree_device, iree_hip_target=args.iree_hip_target, iree_hal_target_backends=args.iree_hal_target_backends, - kv_cache_type=kv_cache_type, + kv_cache_type=args.kv_cache_type, tensor_parallelism_size=args.tensor_parallelism_size, attention_kernel=args.attention_kernel, num_prompts=args.num_prompts, + block_seq_stride=args.block_seq_stride, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git 
a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index ad297bcce..056d8a98e 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -49,7 +49,7 @@ def main(): "--block-seq-stride", help="Block sequence stride for paged KV cache, must divide evenly into the context length", type=int, - default="16", + default=32, ) parser.add_argument( "--verbose", diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 88f5c344c..6cf79402e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -144,7 +144,7 @@ class LlamaModelConfig: # Block sequence stride for a paged KV cache. This must divide evenly # into the context length. - block_seq_stride: int = 16 + block_seq_stride: int = 32 # Either "paged" or "direct". kv_cache_type: str = "paged" diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 0bf252525..75cdbab7a 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -92,7 +92,7 @@ def __init__( iree_hal_target_backends: str, attention_kernel: str, tensor_parallelism_size: int, - block_seq_stride: Optional[int] = None, + block_seq_stride: int, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -180,14 +180,13 @@ def export_to_mlir( f"--output-mlir={mlir_path}", f"--output-config={json_path}", f"--bs={str(self.batch_size)}", + f"--block-seq-stride={self.block_seq_stride}", ] if skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: export_args.append("--attention-kernel") export_args.append(self.attention_kernel) - if self.block_seq_stride: - export_args.append(f"--block-seq-stride={self.block_seq_stride}") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) From aaee29a84760b2b00cee93e3873dcb2ed773ecb6 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:42:12 -0500 Subject: [PATCH 26/39] [tuner] Add unified benchmarking and compilation for models + dispatches (#704) This PR adds `benchmark()` and `compile()` functions to the tuner that can be used for both model and dispatch tuning. The new functions will replace the split benchmark/compile_models and benchmark/compile_dispatches functions. The new benchmarking and compilation functions now use the iree_runtime/iree_compiler python bindings which makes much of the code simpler. Particularly, benchmark results are now mostly parsed by the bindings, and the parse_*_benchmark_results functions are no longer needed. The new compilation and benchmarking flows are described below. ### Compilation ### 1. Populate each CandidateTracker with the input and output filepaths. The input filepaths can be overridden by an optional function argument to the compile() function. This argument can be used for model tuning, passing the model filepath as the new input file. 2. For each candidate, strip the compilation info using iree-opt, and compile to a vmfb with the iree compiler python bindings. Set the candidate's TD spec file (generated during candidate generation), and add any additional iree-compile flags that came from the TuningClient. 
The extra flags are taken from a new abstract TuningClient function called get_iree_compile_flags. 3. For all successful compilations, save the vmfbs to the designated output path, and skip any failed compilation. For any failed compilation, a failure dump is saved instead of the vmfb. 4. Remove duplicate vmfbs, and return the ids of all unique candidates. ### Benchmarking ### 1. Create benchmark task structs for each candidate with its CandidateTracker and the TuningClient 2. Run the candidate benchmarks on the available devices. Each benchmark task will benchmark the vmfb from the CandidateTracker using the iree_runtime python bindings, and return a benchmark result containing the candidate_id, benchmark time, and device_id. 3. Then the same benchmarking is done on the untuned baseline configuration once for each available device. 4. The results from the candidate benchmarks are compared with the baseline benchmarks from the same device, and the fastest candidates are logged and returned. The number of candidates returned is determined by an optional argument to the benchmark function, and all candidates will be returned by default. --------- Signed-off-by: Max Dawkins --- tuner/examples/dispatch/dispatch_tuner.py | 9 + tuner/examples/punet/punet_autotune.py | 9 + tuner/examples/test/README.md | 39 +++ tuner/examples/test/double_mmt.mlir | 16 + tuner/examples/test/tuner_test.py | 133 +++++++- tuner/tuner/candidate_gen.py | 14 +- tuner/tuner/dispatch_parser.py | 6 +- tuner/tuner/libtuner.py | 386 +++++++++++++++++++++- tuner/tuner/op_matchers.py | 5 +- tuner/tuner/spec_builder.py | 57 +++- 10 files changed, 646 insertions(+), 28 deletions(-) create mode 100644 tuner/examples/test/README.md create mode 100644 tuner/examples/test/double_mmt.mlir diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py index 3c2d77f64..0f5b54979 100644 --- a/tuner/examples/dispatch/dispatch_tuner.py +++ b/tuner/examples/dispatch/dispatch_tuner.py @@ -79,6 +79,15 @@ def get_model_benchmark_command( ) -> list[str]: return [] + def get_iree_compile_flags(self) -> list[str]: + return [] + + def get_iree_benchmark_module_flags(self) -> list[str]: + return [] + + def get_benchmark_timeout_s(self) -> int: + return 0 + def main(): args = libtuner.parse_arguments() diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py index 3503c86df..2bfdb4d24 100644 --- a/tuner/examples/punet/punet_autotune.py +++ b/tuner/examples/punet/punet_autotune.py @@ -113,6 +113,15 @@ def get_model_benchmark_command( ] return command + def get_iree_compile_flags(self) -> list[str]: + return [] + + def get_iree_benchmark_module_flags(self) -> list[str]: + return [] + + def get_benchmark_timeout_s(self) -> int: + return 0 + def main(): args = libtuner.parse_arguments() diff --git a/tuner/examples/test/README.md b/tuner/examples/test/README.md new file mode 100644 index 000000000..5dfba0da3 --- /dev/null +++ b/tuner/examples/test/README.md @@ -0,0 +1,39 @@ +# Example Tuner Test + +Example of tuning a dispatch and full model. + +## Environments +Follow instructions in [`/tuner/README.md`](../README.md) + +## Running the Tuner + +### Choose a model to tune +This example uses the simple `double_mmt.mlir` file. + +### Generate a benchmark file +Use the usual `iree-compile` command for your model, add +`--iree-hal-dump-executable-files-to=dump --iree-config-add-tuner-attributes`, +and get the dispatch benchmark that you want to tune. 
For example: +```shell +iree-compile double_mmt.mlir --iree-hal-target-backends=rocm \ + --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump \ + --iree-config-add-tuner-attributes -o /dev/null + +cp dump/module_main_dispatch_0_rocm_hsaco_fb_benchmark.mlir mmt_benchmark.mlir +``` + +### Recommended Trial Run +For an initial trial to test the tuning loop, use: +```shell +python -m examples.test double_mmt.mlir mmt_benchmark.mlir \ + --test_num_dispatch_candidates=5 --test_num_model_candidates=3 \ + --num-candidates=30 +``` + +### Basic Usage +```shell +python -m examples.test \ + --test_num_dispatch_candidates= \ + --test_num_model_candidates= \ + --test_hip_target= \ --num-candidates= +``` diff --git a/tuner/examples/test/double_mmt.mlir b/tuner/examples/test/double_mmt.mlir new file mode 100644 index 000000000..a3bd4c7b0 --- /dev/null +++ b/tuner/examples/test/double_mmt.mlir @@ -0,0 +1,16 @@ +!matA_0 = tensor<2048x2048xf16> +!matB_0 = tensor<2048x2048xf16> +!matC_0 = tensor<2048x2048xf32> + +!matC_1 = tensor<2048x2048xf32> + +func.func @main(%arg0: !matA_0, %arg1: !matB_0) -> !matC_1 { + %cst = arith.constant 0.000000e+00 : f32 + %5 = tensor.empty() : !matC_0 + %6 = linalg.fill ins(%cst : f32) outs(%5 : !matC_0) -> !matC_0 + %7 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0 + %8 = tensor.empty() : !matC_1 + %9 = linalg.fill ins(%cst : f32) outs(%8 : !matC_1) -> !matC_1 + %10 = linalg.matmul_transpose_b ins(%7, %7 : !matC_0, !matC_0) outs(%9 : !matC_1) -> !matC_1 + return %10 : !matC_1 +} diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py index d8c35d60b..528f03b80 100644 --- a/tuner/examples/test/tuner_test.py +++ b/tuner/examples/test/tuner_test.py @@ -4,15 +4,94 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import argparse +from pathlib import Path from tuner import libtuner +class TestTuner(libtuner.TuningClient): + def __init__(self): + super().__init__() + self.compile_flags = ["--compile-from=executable-sources"] + self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"] + + def get_iree_compile_flags(self) -> list[str]: + return self.compile_flags + + def get_iree_benchmark_module_flags(self) -> list[str]: + return self.benchmark_flags + + def get_benchmark_timeout_s(self) -> int: + return 10 + + # TODO(Max191): Remove the following unused abstract functions once they + # are removed from the TuningClient definition. + def get_dispatch_benchmark_timeout_s(self) -> int: + return 0 + + def get_dispatch_compile_timeout_s(self) -> int: + return 0 + + def get_dispatch_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + return [] + + def get_dispatch_benchmark_command( + self, + candidate_tracker: libtuner.CandidateTracker, + ) -> list[str]: + return [] + + def get_model_compile_timeout_s(self) -> int: + return 0 + + def get_model_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + return [] + + def get_model_benchmark_timeout_s(self) -> int: + return 0 + + def get_model_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + return [] + + def main(): - args = libtuner.parse_arguments() + # Custom arguments for the test file. 
+ parser = argparse.ArgumentParser(description="Autotune test script") + test_args = parser.add_argument_group("Example Test Options") + test_args.add_argument( + "test_model_file", type=Path, help="Path to the model file to tune (.mlir)" + ) + test_args.add_argument( + "--test_num_dispatch_candidates", + type=int, + default=None, + help="Number of dispatch candidates to keep for model benchmarks.", + ) + test_args.add_argument( + "--test_num_model_candidates", + type=int, + default=None, + help="Number of model candidates to produce after tuning.", + ) + test_args.add_argument( + "--test_hip_target", + type=str, + default="gfx942", + help="Hip target for tuning.", + ) + # Remaining arguments come from libtuner + args = libtuner.parse_arguments(parser) path_config = libtuner.PathConfig() path_config.base_dir.mkdir(parents=True, exist_ok=True) path_config.output_unilog.touch() + # TODO(Max191): Make candidate_trackers internal to TuningClient. candidate_trackers: list[libtuner.CandidateTracker] = [] stop_after_phase: str = args.stop_after @@ -20,19 +99,69 @@ def main(): libtuner.setup_logging(args, path_config) print(path_config.run_log, end="\n\n") + # TODO(Max191): Some bug seems to be causing OOM errors in benchmarking + # when device validation happens, so this is commented for now. Uncomment + # when the bug is fixed. if not args.dry_run: print("Validating devices") libtuner.validate_devices(args.devices) print("Validation successful!\n") print("Generating candidates...") + test_tuner = TestTuner() candidates = libtuner.generate_candidate_specs( - args, path_config, candidate_trackers + args, path_config, candidate_trackers, test_tuner ) print(f"Stored candidate specs in {path_config.specs_dir}\n") if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: return + print("Compiling candidates...") + compiled_candidates = libtuner.compile( + args, path_config, candidates, candidate_trackers, test_tuner + ) + + print("Benchmarking compiled candidates...") + top_candidates = libtuner.benchmark( + args, + path_config, + compiled_candidates, + candidate_trackers, + test_tuner, + args.test_num_dispatch_candidates, + ) + + print("Compiling models with top candidates...") + test_tuner.compile_flags = [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={args.test_hip_target}", + ] + compiled_model_candidates = libtuner.compile( + args, + path_config, + top_candidates, + candidate_trackers, + test_tuner, + args.test_model_file, + ) + + print("Benchmarking compiled model candidates...") + test_tuner.benchmark_flags = [ + "--benchmark_repetitions=3", + "--input=2048x2048xf16", + "--input=2048x2048xf16", + ] + top_model_candidates = libtuner.benchmark( + args, + path_config, + compiled_model_candidates, + candidate_trackers, + test_tuner, + args.test_num_model_candidates, + ) + + print(f"Top model candidates: {top_model_candidates}") + print("Check the detailed execution logs in:") print(path_config.run_log.resolve()) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index ed4d63f7d..45cb3512a 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -118,14 +118,18 @@ def get_td_spec( class DispatchTunerRegistry: - def __init__(self): + def __init__(self, check_translation_info=True): + self.check_translation_info = check_translation_info self.registry = set() def register(self, dispatch_tuners: list[DispatchTuner]) -> None: for dispatch_tuner in dispatch_tuners: self.registry.add(dispatch_tuner) + # TODO(Max191): Remove translation 
info validation. def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: + if not self.check_translation_info: + return True for attr in attrs: if (attr.name == "translation_info") and ( "LLVMGPUVectorDistribute" in str(attr.attr) @@ -641,7 +645,7 @@ def generate_configs_and_td_specs( limit: int = 4096, # Max candidates to be generated num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints ) -> list[ir.Module]: - dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False) dispatch_tuner_registry.register( [ ContractionOpInterfaceTuner(), @@ -658,10 +662,8 @@ def generate_configs_and_td_specs( ) tune_logger.debug(str(problem_size)) - # Index 0 is reserved for default config, so it gets no td spec. - with ir.Location.unknown() as loc: - empty_module = ir.Module.create(loc) - config_specs: list[ir.Module] = [empty_module] + # Index 0 is reserved for default config, so it gets a placeholder spec. + config_specs: list[ir.Module] = [get_placeholder_spec(input_module.context)] # Get the MMA intrinisic intructions supported by the target. variant_op_list = iree_codegen.get_executable_variant_ops(input_module) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index b45771166..735d6145c 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -110,8 +110,7 @@ def get_contraction_operation( # TODO(Max191): Pass the ir_module directly instead of the template str. def get_shapes(self, template: list[str]) -> ProblemSize: matcher = ContractionOpInterfaceMatcher() - with ir.Context() as ctx: - ir_module = ir.Module.parse("\n".join(template), ctx) + ir_module = ir.Module.parse("\n".join(template)) contraction_op = match_root_op(ir_module, matcher) assert contraction_op is not None, f"contraction op not found" cdims = matcher.contraction_dimensions @@ -161,8 +160,7 @@ def get_conv_operation( # TODO(Max191): Pass the ir_module directly instead of the template str. def get_shapes(self, template: list[str]) -> ProblemSize: - with ir.Context() as ctx: - ir_module = ir.Module.parse("\n".join(template), ctx) + ir_module = ir.Module.parse("\n".join(template)) conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) assert conv_op is not None, f"convolution op not found" lhs_type = ir.RankedTensorType(conv_op.operands[0].type) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 6bece17f4..ff7b78a11 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -18,6 +18,8 @@ """ +import math +import signal import sys import shutil import subprocess @@ -39,9 +41,11 @@ import json from abc import ABC, abstractmethod import iree.runtime as ireert # type: ignore +import iree.compiler as ireec # type: ignore from iree.compiler import ir # type: ignore from . import candidate_gen from . import dispatch_parser +from .op_matchers import * from .common import * @@ -62,6 +66,8 @@ DEVICE_ID_PLACEHOLDER = "!DEVICE_ID!" +# TODO(Max191): Remove most of the fields here after refactoring is complete, +# since many of them will be unused. 
@dataclass class CandidateTracker: candidate_id: int @@ -70,6 +76,7 @@ class CandidateTracker: dispatch_config_path: Optional[Path] = None configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None compilation_successful: Optional[bool] = None + compiled_vmfb_path: Optional[Path] = None compiled_dispatch_path: Optional[Path] = None compiled_dispatch_hash: Optional[str] = None first_benchmark_time: Optional[float] = None @@ -155,11 +162,27 @@ def get_compiled_dispatch_index(self, file_path: Path) -> int: def get_candidate_spec_filename(self, candidate_id: int) -> str: return f"{candidate_id}_spec.mlir" + def get_candidate_vmfb_filename(self, candidate_id: int) -> str: + return f"{candidate_id}.vmfb" + def get_compiled_model_index(self, file_path: Path) -> int: return int(file_path.stem.split("_")[-1]) class TuningClient(ABC): + def __init__(self): + mlir_ctx = ir.Context() + logger = logging.getLogger("tune") + self.tuner_context = TunerContext(mlir_ctx, logger) + + @abstractmethod + def get_iree_compile_flags(self) -> list[str]: + pass + + @abstractmethod + def get_iree_benchmark_module_flags(self) -> list[str]: + pass + @abstractmethod def get_dispatch_compile_command( self, candidate_tracker: CandidateTracker @@ -184,6 +207,10 @@ def get_model_benchmark_command( ) -> list[str]: pass + @abstractmethod + def get_benchmark_timeout_s(self) -> int: + pass + @abstractmethod def get_dispatch_compile_timeout_s(self) -> int: pass @@ -201,6 +228,19 @@ def get_model_benchmark_timeout_s(self) -> int: pass +@dataclass +class CompilePack: + iree_compile_flags: list[str] + candidate_tracker: CandidateTracker + + +@dataclass +class BenchmarkPack: + iree_benchmark_module_flags: list[str] + benchmark_timeout: int + candidate_tracker: CandidateTracker + + @dataclass class RunPack: command: list[str] @@ -237,6 +277,13 @@ class ParsedDisptachBenchmarkResult: candidate_spec_mlir: Path +@dataclass +class BenchmarkResult: + candidate_id: int + time: float + device_id: str + + @dataclass class IREEBenchmarkResult: # Default format follows output of iree-benchmark-module @@ -381,8 +428,12 @@ class ExecutionPhases(str, Enum): benchmark_models = "benchmark-models" -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Autotune script") +def parse_arguments( + initial_parser: Optional[argparse.ArgumentParser] = None, +) -> argparse.Namespace: + parser = initial_parser + if parser is None: + parser = argparse.ArgumentParser(description="Autotune script") # Required arguments required_args = parser.add_argument_group("Required Options") @@ -598,6 +649,161 @@ def run_command(run_pack: RunPack) -> RunResult: return RunResult(result, is_timeout) +# The `strip_root_op_attr` and `strip_compilation_info` functions are used for +# getting consistent inputs to the compilation step in tuning. Inputs may come +# in with lowering configs, translation info, and root_op attrs when the input +# is a benchmark, but not when the input is a source MLIR file. Stripping the +# info makes the inputs to compilation consistent, and allows for overwriting +# the compilation info with generated TD specs during codegen. +def strip_root_op_attr(module: ir.Module): + root_ops: list[ir.Operation] = get_ops_from_module(module, is_root_op) + for root_op in root_ops: + assert ( + ROOT_OP_ATTR_NAME in root_op.opview.attributes + ), f"expected root op to have '{ROOT_OP_ATTR_NAME}' attr" + del root_op.opview.attributes[ROOT_OP_ATTR_NAME] + + +# See the above comment for `strip_root_op_attr`. 
+def strip_compilation_info(input_path: Path) -> str: + # Strip compilation info from the source and save the stripped IR + strip_command = [ + f"iree-opt", + f"{input_path}", + f"--iree-codegen-strip-compilation-info", + ] + result = run_command( + RunPack( + command=strip_command, + check=True, + ) + ) + assert ( + result.process_res is not None + ), "expected result from stripping compilation info" + return result.process_res.stdout + + +def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]: + candidate_tracker = compile_pack.candidate_tracker + + # Compile to vmfb. + assert candidate_tracker.spec_path, "expected candidate spec path" + td_spec_path = candidate_tracker.spec_path.as_posix() + logging.debug( + f"Compiling candidate {candidate_tracker.candidate_id} with spec: {td_spec_path}" + ) + extra_flags = [ + f"--iree-codegen-tuning-spec-path={td_spec_path}", + ] + extra_flags += compile_pack.iree_compile_flags + assert candidate_tracker.compiled_vmfb_path, "expected output vmfb path" + output_path = candidate_tracker.compiled_vmfb_path.as_posix() + crash_dump_path = f"{output_path}.crash_report.mlir" + assert candidate_tracker.mlir_path, "expected input mlir file path" + input_file = candidate_tracker.mlir_path.as_posix() + # TODO(Max191): Make the device in `traget_backends` a command line option + # instead of hardcoding in ireec.compile_str. + try: + ireec.compile_file( + input_file=input_file, + target_backends=["rocm"], + output_file=output_path, + extra_args=extra_flags, + crash_reproducer_path=crash_dump_path, + ) + except ireec.CompilerToolError as e: + logging.info(f"Compilation returned non-zero exit status.") + logging.debug(e) + return None + + return candidate_tracker.candidate_id + + +def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack): + candidate_tracker = benchmark_pack.candidate_tracker + candidate_id = candidate_tracker.candidate_id + + # Load the candidate's vmfb and create vm_module. + vmfb_path = candidate_tracker.compiled_vmfb_path + assert vmfb_path is not None, "expected compiled_vmfb_path" + with open(vmfb_path, "rb") as f: + vmfb_buffer = f.read() + + vm_instance = ireert.VmInstance() + vm_module = ireert.VmModule.copy_buffer(vm_instance, vmfb_buffer) + + # Parse the flags passed from the tuning client and create a kwargs dict + # for the benchmark_module function. + extra_flags = {} + func_name = None + inputs = [] + for flag in benchmark_pack.iree_benchmark_module_flags: + assert flag[:2] == "--", "iree_benchmark_module_flags should begin with '--'" + split_key_value = flag[2:].split("=") + assert ( + len(split_key_value) == 2 + ), "iree_benchmark_module_flags should have the format --=" + key = split_key_value[0] + value = split_key_value[1] + # Allow the tuning client to pass `--function=@func_name`. + if key == "function": + func_name = value + continue + # Special handling for `--input`, since it can be passed many times. + if key == "input": + inputs.append(value) + continue + # Other flags become normal kwargs. + extra_flags[key] = value + + # Benchmark the module. + try: + timeout = benchmark_pack.benchmark_timeout + benchmark_results = ireert.benchmark.benchmark_module( + vm_module, + entry_function=func_name, + inputs=inputs, + device=device_id, + timeout=timeout, + **extra_flags, + ) + except ireert.benchmark.BenchmarkTimeoutError as e: + logging.warning( + f"Benchmark of candidate {candidate_id} timed out after {timeout} seconds." 
+ ) + return BenchmarkResult( + candidate_id=candidate_id, + time=math.inf, + device_id=str(device_id), + ) + + times = [] + for benchmark_result in benchmark_results: + benchmark_name = benchmark_result.benchmark_name + # With multiple benchmark results, there will be `real_time_mean`, but + # not with single iteration benchmark results, so ignore the mean time + # and compute the mean of `real_time`, since the number of iterations + # is up to the tuning client. + if benchmark_name.split("/")[-1] == "real_time": + time_and_unit = benchmark_result.time.split(" ") + assert ( + len(time_and_unit) == 2 + ), "expected the benchmark time to be the time and unit separated by a space." + time_us = IREEBenchmarkResult.unit_to_microseconds( + real_time=float(time_and_unit[0]), + time_unit=time_and_unit[1], + ) + times.append(time_us) + mean_benchmark_time = sum(times) / float(len(times)) + logging.debug(f"Benchmark time of candidate {candidate_id}: {mean_benchmark_time}") + return BenchmarkResult( + candidate_id=candidate_id, + time=mean_benchmark_time, + device_id=str(device_id), + ) + + def run_command_wrapper(task_pack: TaskPack) -> TaskResult: """Help handle extra requirements and record more data for run_command()""" if task_pack.command_need_device_id: @@ -634,9 +840,11 @@ def multiprocess_progress_wrapper( initializer_inputs = initializer_inputs or () # Create a multiprocessing pool + sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) with multiprocessing.Pool( num_worker, initializer, initializer_inputs ) as worker_pool: + signal.signal(signal.SIGINT, sigint_handler) # Use tqdm to create a progress bar with tqdm(total=len(task_list)) as pbar: try: @@ -834,28 +1042,31 @@ def generate_candidate_specs( args: argparse.Namespace, path_config: PathConfig, candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, ) -> list[int]: """Generate candidate transform dialect specs for tuning. Returns the list of candidate indexes""" logging.debug("generate_candidate_specs()") path_config.specs_dir.mkdir(parents=True, exist_ok=True) + shutil.copy(args.input_file, path_config.template_mlir) tune_logger = logging.getLogger("tune") # Generate transform dialect specs. try: - with open(args.input_file, "r") as f: - mlir_text = f.read() - with ir.Context() as ctx: - tuner_context = TunerContext(ctx, tune_logger) - mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_context) + # Strip compilation info before generating td_specs, since the generated + # td_specs can end up matching against the compilation info from the + # source mlir. 
+ mlir_text = strip_compilation_info(path_config.template_mlir) + mlir_module = dispatch_parser.parse_mlir(mlir_text, tuning_client.tuner_context) + with tuning_client.tuner_context.mlir_ctx: logging.debug("Captured messages from candidate_gen.py:") config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( input_module=mlir_module, - tuner_context=tuner_context, + tuner_context=tuning_client.tuner_context, limit=args.num_candidates, num_subgroups=args.num_subgroups, ) - logging.debug("candidate_gen.py ends") + logging.debug("candidate_gen.py ends") handle_error( condition=(len(config_specs) <= 1), msg="Failed to generate any candidates" ) @@ -871,7 +1082,7 @@ def generate_candidate_specs( with open(spec_path, "w") as f: f.write(str(spec)) new_candidate = CandidateTracker( - mlir_path=args.input_file, + mlir_path=path_config.template_mlir, candidate_id=candidate_num, spec_path=spec_path, ) @@ -908,6 +1119,89 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis return collision_detected, unique_indexes +def compile( + args: argparse.Namespace, + path_config: PathConfig, + candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, + input_file: Optional[Path] = None, +) -> list[int]: + logging.debug("compile()") + + if not candidates: + logging.warning("No model candidates to compile.") + return [] + + # If `input_file` is not None, then replace the currently tracked template + # with the passed input mlir file. + if input_file is not None: + shutil.copy(input_file, path_config.template_mlir) + + # Strip compilation info and root_op attribute from the source and save + # the stripped IR, since the TD specs do not expect these attributes. + stripped_mlir = strip_compilation_info(path_config.template_mlir) + context = tuning_client.tuner_context.mlir_ctx + stripped_module = ir.Module.parse(stripped_mlir, context=context) + strip_root_op_attr(stripped_module) + stripped_mlir = str(stripped_module) + with open(path_config.template_mlir, "w") as f: + f.write(stripped_mlir) + + # Set the source and output file paths for compilation of each candidate. + path_config.compiled_dir.mkdir(parents=True, exist_ok=True) + for i in candidates: + vmfb_file_name = path_config.get_candidate_vmfb_filename( + candidate_trackers[i].candidate_id + ) + vmfb_path = path_config.compiled_dir / vmfb_file_name + candidate_trackers[i].compiled_vmfb_path = vmfb_path + candidate_trackers[i].mlir_path = path_config.template_mlir + candidate_trackers[0].mlir_path = path_config.template_mlir + + # Run compilation for all candidates. + task_list = [ + CompilePack( + iree_compile_flags=tuning_client.get_iree_compile_flags(), + candidate_tracker=candidate_trackers[i], + ) + for i in candidates + ] + if 0 not in candidates: + task_list.append( + CompilePack( + iree_compile_flags=tuning_client.get_iree_compile_flags(), + candidate_tracker=candidate_trackers[0], + ) + ) + num_worker = min(args.max_cpu_workers, len(task_list)) + compiled_candidates = multiprocess_progress_wrapper( + num_worker=num_worker, task_list=task_list, function=run_iree_compile_command + ) + compiled_candidates = [c for c in compiled_candidates if c is not None] + success_rate = float(len(compiled_candidates)) / float(len(candidates)) + logging.info( + f"Successfully compiled [{len(compiled_candidates)}] candidates. Success rate: {success_rate}" + ) + + # Remove duplicate vmfbs from the candidate list. 
+ compiled_candidate_hashes = [] + for candidate_id in compiled_candidates: + candidate_vmfb = candidate_trackers[candidate_id].compiled_vmfb_path + hash_val = calculate_md5(candidate_vmfb) + compiled_candidate_hashes.append((candidate_id, hash_val)) + collision_detected, unique_compiled_candidates = collision_handler( + compiled_candidate_hashes + ) + if collision_detected: + compiled_candidates = unique_compiled_candidates + + logging.info(f"Produced [{len(compiled_candidates)}] unique vmfbs") + return compiled_candidates + + +# TODO(Max191): Remove in favor of using `compile` for both model and dispatch +# tuning. def compile_dispatches( args: argparse.Namespace, path_config: PathConfig, @@ -1095,6 +1389,7 @@ def generate_dryrun_model_benchmark_results( return candidate_results, baseline_results +# TODO(Max191): Remove this function in favor of `benchmark`. def benchmark_dispatches( args: argparse.Namespace, path_config: PathConfig, @@ -1172,6 +1467,76 @@ def benchmark_dispatches( return top_candidates +def benchmark( + args: argparse.Namespace, + path_config: PathConfig, + compiled_candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, + num_candidates: Optional[int] = None, +): + logging.debug("benchmark()") + + task_list = [ + BenchmarkPack( + iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(), + benchmark_timeout=tuning_client.get_benchmark_timeout_s(), + candidate_tracker=candidate_trackers[i], + ) + for i in compiled_candidates + if i != 0 + ] + worker_context_queue = create_worker_context_queue(args.devices) + candidate_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=task_list, + function=run_iree_benchmark_module_command, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + # Benchmarking baselines on each involved device. + worker_context_queue = create_worker_context_queue(args.devices) + baseline_task_list = [ + BenchmarkPack( + iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(), + benchmark_timeout=tuning_client.get_benchmark_timeout_s(), + candidate_tracker=candidate_trackers[0], + ) + ] * len(args.devices) + baseline_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=baseline_task_list, + function=run_iree_benchmark_module_command, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + baseline_times_by_device = {} + for r in baseline_results: + baseline_times_by_device[r.device_id] = r.time + + # Select top candidates + def get_speedup(result: BenchmarkResult) -> float: + return result.time / baseline_times_by_device[result.device_id] + + num_top_candidates = len(candidate_results) + if num_candidates is not None: + num_top_candidates = num_candidates + best_results = sorted(candidate_results, key=get_speedup)[:num_top_candidates] + logging.info(f"Selected top[{len(best_results)}]:") + + for r in best_results: + speedup = round(get_speedup(r) * 100, 2) + logging.info( + f"Candidate {r.candidate_id} time: {r.time} ({speedup}% of baseline)" + ) + + top_candidates = [result.candidate_id for result in best_results] + return top_candidates + + +# TODO(Max191): Remove in favor of using `compile` for both model and dispatch +# tuning. def compile_models( args: argparse.Namespace, path_config: PathConfig, @@ -1375,6 +1740,7 @@ def parse_model_benchmark_results( return dump_list +# TODO(Max191): Remove this function in favor of `benchmark`. 
def benchmark_models( args: argparse.Namespace, path_config: PathConfig, diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py index 1abdafd3d..db953fbb3 100644 --- a/tuner/tuner/op_matchers.py +++ b/tuner/tuner/op_matchers.py @@ -39,9 +39,12 @@ def get_ops_from_module(module: ir.Module, fn): return ops +ROOT_OP_ATTR_NAME = "root_op" + + def is_root_op(op: ir.Operation) -> bool: for attr in op.opview.attributes: - if attr.name == "root_op": + if attr.name == ROOT_OP_ATTR_NAME: return True return False diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py index a27bd072f..6005e89ae 100644 --- a/tuner/tuner/spec_builder.py +++ b/tuner/tuner/spec_builder.py @@ -13,6 +13,20 @@ from .common import * from .dispatch_constraints import * from .dispatch_parser import * +from .op_matchers import ROOT_OP_ATTR_NAME + + +def get_placeholder_spec(context: ir.Context) -> ir.Module: + spec_text = f""" + module attributes {{ transform.with_named_sequence }} {{ + transform.named_sequence + @__kernel_config(%variant_op: !transform.any_op {{transform.readonly}}) -> !transform.any_op + attributes {{ iree_codegen.tuning_spec_entrypoint }} {{ + transform.yield %variant_op : !transform.any_op + }} + }} + """ + return ir.Module.parse(spec_text, context) # TODO(Max191): Use python bindings to build the transform dialect spec module @@ -24,12 +38,43 @@ def build_td_spec( func_name: str, ) -> ir.Module: bbargs = [] + # The `root_op` attribute will prevent matching of ops without the attr in + # the resulting TD spec matcher if it is not removed, so we remove it here. + # After removing, we must add it back, since the op is connected to the + # input module, which gets used for all candidates. + # TODO(Max191): Find a cleaner way to do this without removing and adding + # back the attribute. + has_root_attr = ROOT_OP_ATTR_NAME in op.opview.attributes + if has_root_attr: + assert isinstance( + op.opview.attributes[ROOT_OP_ATTR_NAME], ir.UnitAttr + ), f"expected '{ROOT_OP_ATTR_NAME}' attr to be a unit attr" + if has_root_attr: + del op.opview.attributes[ROOT_OP_ATTR_NAME] + # Get the root op string for formatting the final spec. + root_operation = str(op) + if has_root_attr: + op.opview.attributes[ROOT_OP_ATTR_NAME] = ir.UnitAttr.get(op.context) + + # Get the names ssa names of operands to make sure they match in the + # template after string formatting. + captured_values: set[ir.Value] = set() for operand in op.operands: + if operand in captured_values: + # TODO(Max191): Remove this warning when the transform for the + # `cast_compatible_dag_from_root` op fixes a bug in the matching + # logic that causes failure to match when the same operand is + # repeated. For now, still avoid adding duplicate SSA values to + # prevent parsing failure. + logging.warning( + f"Root op has repeated operand. This can cause failure to match in the resulting TD spec at compile time." 
+ ) + continue ssa_name = operand.get_name() operand_type = operand.type bbargs.append(f"{ssa_name}: {operand_type}") + captured_values.add(operand) bbargs_str = ", ".join(bbargs) - root_operation = str(op) spec_text = f""" module attributes {{ transform.with_named_sequence }} {{ // Annotation Transform @@ -51,11 +96,13 @@ def build_td_spec( }} // Entry Point - transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) {{ - transform.foreach_match in %variant_op + transform.named_sequence + @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) -> !transform.any_op + attributes {{ iree_codegen.tuning_spec_entrypoint }} {{ + %res = transform.foreach_match in %variant_op @{func_name} -> @apply_op_config - : (!transform.any_op) -> (!transform.any_op) - transform.yield + : (!transform.any_op) -> !transform.any_op + transform.yield %res : !transform.any_op }} }} """ From 147228be8f98d037bd8f4d8733f0c75b24676de2 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 19 Dec 2024 00:17:45 +0100 Subject: [PATCH 27/39] [sharktank] Avoid torch pre-releases (#712) Drops passing `--pre` to avoid that pre-releases of torch get installed and switch to stable channel. --- .github/workflows/ci-sharktank.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index d46efef09..cc39f9b6b 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -74,7 +74,7 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --no-compile --pre --index-url https://download.pytorch.org/whl/test/cpu torch==${{matrix.torch-version}}+cpu + pip install --no-compile --index-url https://download.pytorch.org/whl/cpu torch==${{matrix.torch-version}}+cpu # Install nightly IREE packages. # We could also pin to a known working or stable version. From fecc081786eb07245c4c696ac17f3f41515da439 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:32:26 -0800 Subject: [PATCH 28/39] [sharktank/shortfin] Schedule sharktank CI after IREE nightly release (#714) Schedule sharktank CI after IREE nightly release, which happens around 1.20 AM to 1.40 AM PST. --- .github/workflows/ci-llama-large-tests.yaml | 4 ++-- .github/workflows/ci-sglang-benchmark.yml | 4 ++-- .github/workflows/ci_eval.yaml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 2eb8e6496..4849dd188 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -9,8 +9,8 @@ name: Llama Benchmarking Tests on: workflow_dispatch: schedule: - # Weekdays at 4:00 AM UTC = 9:00 PM PST. - - cron: "0 4 * * 1-5" + # Weekdays at 11:00 AM UTC = 03:00 AM PST / 04:00 AM PDT + - cron: "0 11 * * 1-5" concurrency: # A PR number if a pull request and otherwise the commit hash. 
This cancels diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index c549fc25a..483a28e25 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -21,8 +21,8 @@ name: SGLang Llama Benchmarking Tests on: workflow_dispatch: schedule: - # Weekdays at 4:00 AM UTC = 9:00 PM PST. - - cron: "0 4 * * 1-5" + # Weekdays at 11:00 AM UTC = 03:00 AM PST / 04:00 AM PDT + - cron: "0 11 * * 1-5" concurrency: # A PR number if a pull request and otherwise the commit hash. This cancels diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index a71698774..c8b782c95 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -9,8 +9,8 @@ name: CI - sharktank perplexity on: workflow_dispatch: schedule: - # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT. - - cron: "0 7 * * 1-5" + # Weekdays at 11:00 AM UTC = 03:00 AM PST / 04:00 AM PDT + - cron: "0 11 * * 1-5" concurrency: # A PR number if a pull request and otherwise the commit hash. This cancels From d520bd1c337cc2a4e5c79433439ba3cccbec5ed2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 19 Dec 2024 12:26:27 -0800 Subject: [PATCH 29/39] Replication of index was causing issues for kv cache writes (#715) Signed-off-by: Rob Suderman --- .../sharktank/examples/export_paged_llm_v1.py | 3 ++- sharktank/sharktank/layers/kv_cache.py | 23 +++++++++++++++---- sharktank/sharktank/utils/cli.py | 7 +++++- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 056d8a98e..312a53d33 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -334,7 +334,8 @@ def _( bsizes = [] for bs in args.bs: - generate_batch_prefill(bs) + if not args.skip_prefill: + generate_batch_prefill(bs) if not args.skip_decode: generate_batch_decode(bs) bsizes.append(bs) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 46e94ff90..f62002f46 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -456,12 +456,25 @@ def write_timestep( page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1) # [1, 1] - partitions = torch.tensor(idx).unsqueeze(0) + if isinstance(seq_positions, ReplicatedTensor): + partitions = [ + torch.tensor(idx).unsqueeze(0) + for _ in range(seq_positions.shard_count) + ] + + transformer_block = [ + torch.full((bs, 1), transformer_block_index, device=device) + for _ in range(seq_positions.shard_count) + ] + + partitions = ReplicatedTensor(ts=partitions) + transformer_block = ReplicatedTensor(ts=transformer_block) + else: + partitions = torch.tensor(idx).unsqueeze(0) + transformer_block = torch.full( + (bs, 1), transformer_block_index, device=device + ) - # [bs, 1] - transformer_block = torch.full( - (bs, 1), transformer_block_index, device=device - ) partitions = partitions.repeat(bs, 1) indices = (page_id, transformer_block, partitions, page_offset) diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index b4b405dca..cdcdd8c2c 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -69,9 +69,14 @@ def add_model_options(parser: argparse.ArgumentParser): default="torch", choices=["decomposed", "torch"], ) + parser.add_argument( + "--skip-prefill", + help="Skips exporting prefill", + 
action="store_true", + ) parser.add_argument( "--skip-decode", - help="Enables prefill only, skips decode", + help="Skips export decode", action="store_true", ) From 062b1ae8a3358a222fb1f400c45d2ab19250d172 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 19 Dec 2024 13:04:36 -0800 Subject: [PATCH 30/39] Iterate on llama user guide. (#716) Progress on https://github.com/nod-ai/shark-ai/issues/691, trying to simplify a few steps before putting this into release notes for 3.1.0. * Add suggested `export/` directory to `.gitignore` (I'd prefer for the tools to default to a path in the user's homedir, but this is a less invasive change) * Remove `sharktank` from install instructions as it is included in `shark-ai` nightly releases now * Rework "Verify server" section to start with a health check then use `curl`. Keep the Python sample code for now, though similar projects usually also have a Python API for interfacing with LLMs. We also don't use a standardized HTTP API yet (like the OpenAI API). Maybe the SGLang integration will be more natural for users. --- .gitignore | 3 +- docs/shortfin/llm/user/e2e_llama8b_mi300x.md | 63 ++++++++++++-------- docs/user_guide.md | 1 + 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index bdb0b5387..daf8f6fda 100644 --- a/.gitignore +++ b/.gitignore @@ -33,12 +33,13 @@ wheelhouse # Local-only config options version_local.json -#Model artifacts +# Model artifacts *.pt *.safetensors *.gguf *.vmfb genfiles/ +export/ *.zip tmp/ diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md index 313a8086c..36ea817f2 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -24,25 +24,24 @@ source .venv/bin/activate ## Install stable shark-ai packages - +First install a torch version that fulfills your needs: ```bash -pip install shark-ai[apps] sharktank +# Fast installation of torch with just CPU support. +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ``` -### Nightly packages +For other options, see https://pytorch.org/get-started/locally/. -To install nightly packages: - - +Next install shark-ai: ```bash -pip install shark-ai[apps] sharktank \ - --pre --find-links https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +pip install shark-ai[apps] ``` -See also the -[instructions here](https://github.com/nod-ai/shark-ai/blob/main/docs/nightly_releases.md). +> [!TIP] +> To switch from the stable release channel to the nightly release channel, +> see [`nightly_releases.md`](../../../nightly_releases.md). ### Define a directory for export files @@ -192,25 +191,41 @@ cat shortfin_llm_server.log [2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` -## Verify server +## Test the server -We can now verify our LLM server by sending a simple request: +We can now test our LLM server. 
-### Open python shell +First let's confirm that it is running: ```bash -python +curl -i http://localhost:8000/health + +# HTTP/1.1 200 OK +# date: Thu, 19 Dec 2024 19:40:43 GMT +# server: uvicorn +# content-length: 0 ``` -### Send request +Next, let's send a generation request: -```python -import requests +```bash +curl http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Name the capital of the United States.", + "sampling_params": {"max_completion_tokens": 50} + }' +``` + +### Send requests from Python +You can also send HTTP requests from Python like so: + +```python import os +import requests port = 8000 # Change if running on a different port - generate_url = f"http://localhost:{port}/generate" def generation_request(): @@ -225,16 +240,16 @@ def generation_request(): generation_request() ``` -After you receive the request, you can exit the python shell: +## Cleanup + +When done, you can stop the shortfin_llm_server by killing the process: ```bash -quit() +kill -9 $shortfin_process ``` -## Cleanup - -When done, you can kill the shortfin_llm_server by killing the process: +If you want to find the process again: ```bash -kill -9 $shortfin_process +ps -f | grep shortfin ``` diff --git a/docs/user_guide.md b/docs/user_guide.md index d3ef192e0..c4c3fdb58 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -17,6 +17,7 @@ Officially we support Python versions: 3.11, 3.12, 3.13 The rest of this guide assumes you are using Python 3.11. ### Install Python + To install Python 3.11 on Ubuntu: ```bash From 5e3f5e645abafc58c69a7f871e4b18b318d6653d Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:39:52 -0800 Subject: [PATCH 31/39] Adds sdxl's VAE decoder implementation (#653) --- .github/workflows/ci-sharktank.yml | 2 + sharktank/conftest.py | 9 + sharktank/requirements-tests.txt | 1 + sharktank/sharktank/models/punet/layers.py | 1 - sharktank/sharktank/models/vae/README.md | 28 ++ sharktank/sharktank/models/vae/config.py | 61 ++++ sharktank/sharktank/models/vae/layers.py | 261 ++++++++++++++++++ sharktank/sharktank/models/vae/model.py | 130 +++++++++ .../models/vae/tools/diffuser_ref.py | 56 ++++ .../sharktank/models/vae/tools/run_vae.py | 139 ++++++++++ .../sharktank/models/vae/tools/sample_data.py | 17 ++ sharktank/sharktank/ops/default_impls.py | 8 + sharktank/sharktank/ops/signatures.py | 20 ++ sharktank/sharktank/types/tensors.py | 5 + sharktank/tests/models/vae/vae_test.py | 202 ++++++++++++++ 15 files changed, 939 insertions(+), 1 deletion(-) create mode 100644 sharktank/sharktank/models/vae/README.md create mode 100644 sharktank/sharktank/models/vae/config.py create mode 100644 sharktank/sharktank/models/vae/layers.py create mode 100644 sharktank/sharktank/models/vae/model.py create mode 100644 sharktank/sharktank/models/vae/tools/diffuser_ref.py create mode 100644 sharktank/sharktank/models/vae/tools/run_vae.py create mode 100644 sharktank/sharktank/models/vae/tools/sample_data.py create mode 100644 sharktank/tests/models/vae/vae_test.py diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index cc39f9b6b..4cdd2b274 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -146,9 +146,11 @@ jobs: --with-clip-data \ --with-flux-data \ --with-t5-data \ + --with-vae-data \ sharktank/tests/models/clip/clip_test.py \ sharktank/tests/models/t5/t5_test.py \ sharktank/tests/models/flux/flux_test.py \ + 
sharktank/tests/models/vae/vae_test.py \ --durations=0 diff --git a/sharktank/conftest.py b/sharktank/conftest.py index d7118893a..5d16b5ff2 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -116,6 +116,15 @@ def pytest_addoption(parser): ), ) + parser.addoption( + "--with-vae-data", + action="store_true", + default=False, + help=( + "Enable tests that use vae data such as models not part of the source code." + ), + ) + # TODO: Remove all hardcoded paths in CI tests parser.addoption( "--llama3-8b-tokenizer-path", diff --git a/sharktank/requirements-tests.txt b/sharktank/requirements-tests.txt index d5b4b0c0e..a0ddf6117 100644 --- a/sharktank/requirements-tests.txt +++ b/sharktank/requirements-tests.txt @@ -2,3 +2,4 @@ datasets==3.0.0 parameterized pytest==8.0.0 pytest-html +diffusers diff --git a/sharktank/sharktank/models/punet/layers.py b/sharktank/sharktank/models/punet/layers.py index 9294deee2..f9b4ad8fd 100644 --- a/sharktank/sharktank/models/punet/layers.py +++ b/sharktank/sharktank/models/punet/layers.py @@ -571,7 +571,6 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tenso hidden_states = ops.elementwise(self.nonlinearity, hidden_states) hidden_states = self.conv1(hidden_states) - assert self.time_emb_proj is not None if self.time_emb_proj is not None: temb = ops.elementwise(self.nonlinearity, temb) temb = self.time_emb_proj(temb)[:, :, None, None] diff --git a/sharktank/sharktank/models/vae/README.md b/sharktank/sharktank/models/vae/README.md new file mode 100644 index 000000000..c4b8934c3 --- /dev/null +++ b/sharktank/sharktank/models/vae/README.md @@ -0,0 +1,28 @@ +# VAE decoder + +This is vae implemented in the style used for SDXL and referenced from diffusers implementation. + +## Preparing dataset +If not sharding or quantizing, the official model can be imported as from huggingface: + +``` +model_dir=$(huggingface-cli download \ + stabilityai/stable-diffusion-xl-base-1.0 \ + vae/config.json vae/diffusion_pytorch_model.safetensors) + +python -m sharktank.models.punet.tools.import_hf_dataset \ + --params $model_dir/vae/diffusion_pytorch_model.safetensors + --config-json $model_dir/vae/config.json --output-irpa-file ~/models/vae.irpa +``` + +# Run Vae decoder model eager mode +``` +python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu +``` + +## License + +Significant portions of this implementation were derived from diffusers, +licensed under Apache2: https://github.com/huggingface/diffusers +While much was a simple reverse engineering of the config.json and parameters, +code was taken where appropriate. diff --git a/sharktank/sharktank/models/vae/config.py b/sharktank/sharktank/models/vae/config.py new file mode 100644 index 000000000..9ee2f0427 --- /dev/null +++ b/sharktank/sharktank/models/vae/config.py @@ -0,0 +1,61 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# Significant portions of this implementation were derived from diffusers, +# licensed under Apache2: https://github.com/huggingface/diffusers +# While much was a simple reverse engineering of the config.json and parameters, +# code was taken where appropriate. 
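Alongside the `run_vae` CLI entry point shown in the README above, the decoder added in this patch can also be driven directly from Python. The following is a minimal eager-mode sketch using the `Dataset`, `VaeDecoderModel`, and `get_random_inputs` helpers introduced in this change; the `vae.irpa` path is a placeholder for whatever `--output-irpa-file` produced.

```python
# Minimal eager-mode sketch built on the APIs added in this patch.
# "vae.irpa" is a placeholder for the file written by --output-irpa-file above.
import torch

from sharktank.types import Dataset
from sharktank.models.vae.model import VaeDecoderModel
from sharktank.models.vae.tools.sample_data import get_random_inputs

ds = Dataset.load("vae.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

# Random 128x128 latents for a 1024x1024 output image: (bs, 4, width // 8, height // 8).
latents = get_random_inputs(dtype=torch.float32, device="cpu", bs=1)
images = model.forward(latents)
print(images.shape)  # expected to be (1, 3, 1024, 1024), with values clamped to [0, 1]
```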
+ +from typing import List, Optional, Sequence, Tuple, Union + +from dataclasses import dataclass +import inspect +import warnings + +__all__ = [ + "HParams", +] + + +@dataclass +class HParams: + # Per block sequences. These are normalized from either an int (dubplicated + # to the number of down_blocks or a list. + layers_per_block: Tuple[int] + + act_fn: str = "silu" + block_out_channels: Sequence[int] = (128, 256, 512, 512) + in_channels: int = 3 + up_block_types: Sequence[str] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ) + layers_per_block: int = 2 + norm_num_groups: int = 32 + scaling_factor: float = 0.13025 + + def assert_default_values(self, attr_names: Sequence[str]): + for name in attr_names: + actual = getattr(self, name) + required = getattr(HParams, name) + if actual != required: + raise ValueError( + f"NYI: HParams.{name} != {required!r} (got {actual!r})" + ) + + @classmethod + def from_dict(cls, d: dict): + if "layers_per_block" not in d: + d["layers_per_block"] = 2 + + allowed = inspect.signature(cls).parameters + declared_kwargs = {k: v for k, v in d.items() if k in allowed} + extra_kwargs = [k for k in d.keys() if k not in allowed] + if extra_kwargs: + warnings.warn(f"Unhandled vae.HParams: {extra_kwargs}") + return cls(**declared_kwargs) diff --git a/sharktank/sharktank/models/vae/layers.py b/sharktank/sharktank/models/vae/layers.py new file mode 100644 index 000000000..0d7033f4a --- /dev/null +++ b/sharktank/sharktank/models/vae/layers.py @@ -0,0 +1,261 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional, Sequence, Tuple + +import math + +import torch +import torch.nn as nn + +from sharktank import ops +from sharktank.layers import * +from sharktank.types import * +from sharktank.models.punet.layers import ( + ResnetBlock2D, + Upsample2D, + GroupNormLayer, + AttentionLayer, +) +from .config import * + + +__all__ = ["UNetMidBlock2D", "UpDecoderBlock2D", "AttentionLayer"] + +# TODO Remove and integrate with punet AttentionLayer +class AttentionLayer(ThetaLayer): + def __init__( + self, + theta: Theta, + heads: int, # in_channels // attention_head_dim + dim_head, + rescale_output_factor: float, + eps: float, + norm_num_groups: int, + residual_connection: bool, + ): + super().__init__(theta) + self.heads = heads + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + + if norm_num_groups is not None: + self.group_norm = GroupNormLayer( + theta("group_norm"), num_groups=norm_num_groups, eps=eps + ) + else: + self.group_norm = None + + self.norm_q = None + self.norm_k = None + + self.norm_cross = None + self.to_q = LinearLayer(theta("to_q")) + self.to_k = LinearLayer(theta("to_k")) + self.to_v = LinearLayer(theta("to_v")) + + self.added_proj_bias = True + self.to_out = LinearLayer(theta("to_out")(0)) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if 
encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + hidden_states = ops.scaled_dot_product_attention( + query, key, value, a=attention_mask + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, self.heads * head_dim + ) + + # linear proj + hidden_states = self.to_out(hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + return hidden_states + + +class UpDecoderBlock2D(ThetaLayer): + def __init__( + self, + theta: Theta, + *, + num_layers: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + resnet_out_scale_factor: Optional[float], + resnet_time_scale_shift: str, + temb_channels: int, + dropout: float, + add_upsample: bool, + ): + super().__init__(theta) + resnets = [] + + for i in range(num_layers): + resnets.append( + ResnetBlock2D( + theta("resnets")(i), + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + output_scale_factor=resnet_out_scale_factor, + time_embedding_norm=resnet_time_scale_shift, + temb_channels=temb_channels, + dropout=dropout, + ) + ) + self.resnets = nn.ModuleList(resnets) + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(theta("upsamplers")("0"), padding=1)] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, + ) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + return hidden_states + + +class UNetMidBlock2D(ThetaLayer): + def __init__( + self, + theta: Theta, + temb_channels: int, + dropout: float, + num_layers: int, + resnet_eps: float, + resnet_time_scale_shift: str, + resnet_act_fn: str, + resnet_groups: int, + resnet_pre_norm: bool, + add_attention: bool, + attention_head_dim: int, + output_scale_factor: float, + attn_groups: Optional[int] = None, + ): + super().__init__(theta) + + resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + # TODO Implement ResnetBlockCondNorm2d block for spatial time scale shift + raise AssertionError(f"ResnetBlockCondNorm2d not yet implemented") + else: + resnets = [ + ResnetBlock2D( + theta("resnets")(0), + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + time_embedding_norm=resnet_time_scale_shift, + 
temb_channels=temb_channels, + dropout=dropout, + ) + ] + # TODO: loop through num_layers properly. Works for sdxl vae specifically but removed for export reasons + if add_attention: + self.attention = AttentionLayer( + theta("attentions")(0), + heads=1, + dim_head=attention_head_dim, + rescale_output_factor=1.0, + eps=resnet_eps, + norm_num_groups=attn_groups, + residual_connection=True, + ) + else: + self.attention = None + + if resnet_time_scale_shift == "spatial": + # TODO Implement ResnetBlock2D for spatial time scale shift support + raise AssertionError( + f"ResnetBlock2D spatial time scale shift not yet implemented" + ) + else: + resnets.append( + ResnetBlock2D( + theta("resnets")(1), + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + time_embedding_norm=resnet_time_scale_shift, + temb_channels=temb_channels, + dropout=dropout, + ) + ) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + if self.attention is not None: + hidden_states = self.attention(hidden_states) + hidden_states = self.resnets[1](hidden_states, temb) + return hidden_states diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py new file mode 100644 index 000000000..1054108c7 --- /dev/null +++ b/sharktank/sharktank/models/vae/model.py @@ -0,0 +1,130 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import torch.nn as nn + +from sharktank.layers import * +from sharktank import ops +from ...types import * + +from .config import * +from .layers import * +from sharktank.models.punet.layers import UpDownBlock2D, GroupNormLayer +from typing import Optional + + +class VaeDecoderModel(ThetaLayer): + @classmethod + def from_dataset(cls, ds: Dataset) -> "VaeDecoderModel": + hp = HParams.from_dict(ds.properties["hparams"]) + return cls(hp, ds.root_theta) + + def __init__(self, hp: HParams, theta: Theta): + super().__init__(theta) + self.hp = hp + + # input conv + self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0)) + self.conv_in = Conv2DLayer(theta("decoder")("conv_in"), padding=(1, 1)) + # Mid + self.mid_block = self._create_mid_block(theta("decoder")("mid_block")) + # up + self.up_blocks = nn.ModuleList([]) + self.upscale_dtype = theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")( + "weight" + ).dtype + for i, up_block_name in enumerate(hp.up_block_types): + up_block_theta = theta("decoder")("up_blocks")(i) + is_final_block = i == len(hp.block_out_channels) - 1 + self.up_blocks.append( + self._create_up_block( + up_block_theta, + up_block_name, + is_final_block=is_final_block, + ) + ) + # TODO add spatial norm type support + self.conv_norm_out = GroupNormLayer( + theta("decoder")("conv_norm_out"), num_groups=hp.norm_num_groups, eps=1e-6 + ) + + self.conv_act = nn.SiLU() + self.conv_out = Conv2DLayer(theta("decoder")("conv_out"), padding=(1, 1)) + + def forward( + self, sample: torch.Tensor, latent_embeds: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + The forward method of the 'Decoder' class + Args: + sample ('torch.Tensor') input latents of shape (batch_size, num_channels, height, width) + + """ + self.trace_goldens( + 
"inputs", + { + "sample": sample, + "latent_embeds": latent_embeds, + }, + ) + sample = 1 / self.hp.scaling_factor * sample + + sample = self.post_quant_conv(sample) + sample = self.conv_in(sample) + self.trace_golden("conv_in", sample) + # TODO add training and gradient checkpointing support + sample = self.mid_block(sample, latent_embeds) + self.trace_golden("mid_block", sample) + + sample = sample.to(self.upscale_dtype) + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + + sample = self.conv_act(sample) + sample = self.conv_out(sample) + sample = (sample / 2 + 0.5).clamp(0, 1) + return sample + + def _create_mid_block(self, mid_block_theta: Theta) -> nn.Module: + hp = self.hp + return UNetMidBlock2D( + mid_block_theta, + temb_channels=None, + dropout=0.0, + num_layers=hp.layers_per_block, + resnet_eps=1e-6, + resnet_act_fn="swish", + resnet_groups=hp.norm_num_groups, + attn_groups=hp.norm_num_groups, + resnet_pre_norm=True, + add_attention=True, + attention_head_dim=hp.block_out_channels[-1], + output_scale_factor=1.0, + resnet_time_scale_shift="default", + ) + + def _create_up_block( + self, up_block_theta: Theta, type_name: str, is_final_block: bool + ) -> nn.Module: + hp = self.hp + if type_name == "UpDecoderBlock2D": + return UpDecoderBlock2D( + up_block_theta, + num_layers=hp.layers_per_block + 1, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=hp.act_fn, + resnet_groups=hp.norm_num_groups, + resnet_time_scale_shift="default", + temb_channels=None, + dropout=0.0, + resnet_out_scale_factor=None, + ) diff --git a/sharktank/sharktank/models/vae/tools/diffuser_ref.py b/sharktank/sharktank/models/vae/tools/diffuser_ref.py new file mode 100644 index 000000000..c2c283197 --- /dev/null +++ b/sharktank/sharktank/models/vae/tools/diffuser_ref.py @@ -0,0 +1,56 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import torch +from diffusers import AutoencoderKL + + +class VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + custom_vae="", + ): + super().__init__() + self.vae = None + if custom_vae in ["", None]: + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + elif "safetensors" in custom_vae: + custom_vae = safetensors.torch.load_file(custom_vae) + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) + elif not isinstance(custom_vae, dict): + try: + # custom HF repo with no vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + ) + except: + # some larger repo with vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + ) + + def decode(self, inp): + # The reference vae decode does not do scaling and leaves it for the sdxl pipeline. 
We integrate it into vae for pipeline performance so using the hardcoded values from the config.json here + img = 1 / 0.13025 * inp + x = self.vae.decode(img, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) + + +def run_torch_vae(hf_model_name, example_input): + + vae_model = VaeModel(hf_model_name) + return vae_model.decode(example_input) diff --git a/sharktank/sharktank/models/vae/tools/run_vae.py b/sharktank/sharktank/models/vae/tools/run_vae.py new file mode 100644 index 000000000..540436fd1 --- /dev/null +++ b/sharktank/sharktank/models/vae/tools/run_vae.py @@ -0,0 +1,139 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +import sys + +import torch + +from iree.turbine import aot + +from ..model import VaeDecoderModel +from ....utils.patching import SaveModuleResultTensorsPatch + +from .sample_data import get_random_inputs +from sharktank.models.punet.tools.sample_data import load_inputs, save_outputs +from iree.turbine.aot import FxProgramsBuilder, export, decompositions + +from iree.turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) + + +def export_vae(model, sample_inputs, decomp_attn): + decomp_list = [] + if decomp_attn: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, add_ops=decomp_list + ): + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + name=f"forward", + args=tuple(torch.unsqueeze(sample_inputs, 0)), + strict=False, + ) + def _( + model, + sample_inputs, + ): + return model(sample_inputs) + + output = export(fxb, import_symbolic_shape_expressions=True) + return output + + +def main(argv): + from ....utils import cli + + parser = cli.create_parser() + cli.add_input_dataset_options(parser) + parser.add_argument("--device", default="cuda:0", help="Torch device to run on") + parser.add_argument("--dtype", default="float16", help="DType to run in") + parser.add_argument("--export", type=Path, help="Export to path (vs run)") + parser.add_argument("--bs", default=1, type=int, help="Batch size for export") + parser.add_argument( + "--inputs", + type=Path, + help="Safetensors file of inputs (or random if not given)", + ) + parser.add_argument( + "--outputs", + type=Path, + help="Safetensors file of outputs", + ) + parser.add_argument( + "--save-intermediates-path", + type=Path, + help="Path of safetensors file in which to save all module outputs", + ) + parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Compares results vs HF diffusers reference model", + ) + parser.add_argument( + "--decomp_attn", + action="store_true", + help="Decomposes the attention op during export", + ) + args = cli.parse(parser, args=argv) + + device = args.device + dtype = getattr(torch, args.dtype) + + ds = cli.get_input_dataset(args) + ds.to(device=device) + + mdl = VaeDecoderModel.from_dataset(ds) + + # Run a step for debugging. 
+ if args.inputs: + inputs = load_inputs(args.inputs, dtype=dtype, device=device, bs=args.bs) + else: + inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs) + + if args.export: + # TODO move export from a run_vae file + output = export_vae(mdl, inputs, args.decomp_attn) + output.save_mlir(args.export) + print("exported VAE model. Skipping eager execution") + else: + # Save intermediates. + intermediates_saver = None + if args.save_intermediates_path: + intermediates_saver = SaveModuleResultTensorsPatch() + intermediates_saver.patch_child_modules(mdl.cond_model) + + results = mdl.forward(inputs) + print("results:", results) + + if args.outputs: + print(f"Saving outputs to {args.outputs}") + save_outputs(args.outputs, results) + + if intermediates_saver: + print(f"Saving intermediate tensors to: {args.save_intermediates_path}") + intermediates_saver.save_file(args.save_intermediates_path) + + if args.compare_vs_torch: + from .diffuser_ref import run_torch_vae + + diffusers_results = run_torch_vae( + "stabilityai/stable-diffusion-xl-base-1.0", inputs + ) + print("diffusers results:", diffusers_results) + torch.testing.assert_close(diffusers_results, results) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/models/vae/tools/sample_data.py b/sharktank/sharktank/models/vae/tools/sample_data.py new file mode 100644 index 000000000..cd946088e --- /dev/null +++ b/sharktank/sharktank/models/vae/tools/sample_data.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Various utilities for deriving stable sample data for the model.""" + +from pathlib import Path + +import torch + + +def get_random_inputs(dtype, device, bs: int = 2): + height = 1024 + width = 1024 + return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index d66e97233..bbead9572 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -476,6 +476,14 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tenso return torch.unsqueeze(tensor, dim) +@squeeze.override(AllOfType(AnyTensor, PrimitiveTensor)) +def squeeze_default(tensor, dim: Optional[int] = None) -> AnyTensor: + if dim is None: + return torch.squeeze(unbox_tensor(tensor)) + else: + return torch.squeeze(unbox_tensor(tensor), dim) + + @view.override(Tensor) def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor: return unbox_tensor(tensor).view(*shape) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index cbe959d28..00a76b42b 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -55,6 +55,7 @@ "sharded_cat", "sharded_sum", "softmax", + "squeeze", "to", "transfer_to_logical_device", "transpose", @@ -1104,6 +1105,25 @@ def _unsqueeze_trampoline( d.fail(tensors) +@overridable +def squeeze(tensor, dim: Optional[int]) -> AnyTensor: + """See torch.squeeze""" + ... 
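Together with the trampoline below and the default implementation registered in `default_impls.py`, this declaration routes `ops.squeeze` through the dispatcher while keeping `torch.squeeze` semantics. A small usage sketch follows; note that `dim` is passed explicitly, since the signature above does not give it a default.

```python
# Usage sketch: plain torch.Tensor inputs fall through to torch.squeeze via the
# default override registered in default_impls.py.
import torch

from sharktank import ops

x = torch.zeros(1, 3, 1, 5)

print(ops.squeeze(x, None).shape)  # torch.Size([3, 5])    - all size-1 dims removed
print(ops.squeeze(x, 0).shape)     # torch.Size([3, 1, 5]) - only dim 0 removed
```

The companion method added in `sharktank/types/tensors.py` exposes the same operation as `tensor.squeeze(dim)` on sharktank tensor types.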
+ + +@squeeze.trampoline +def _squeeze_trampoline( + d: SignatureDispatcher, tensor, dim: Optional[int] +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensor): + result = override(tensor, dim) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def view(tensor: AnyTensor, shape: List[int]) -> AnyTensor: """See torch.Tensor.view""" diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 153a5d753..9470aba7e 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -383,6 +383,11 @@ def size(self, dim: Optional[int] = None) -> tuple[int]: return tuple(self.shape) return self.shape[dim] + def squeeze(self, dim: Optional[int] = None) -> "AnyTensor": + from ..ops import squeeze + + return squeeze(self, dim) + def transpose(self, dim0: int, dim1: int) -> "AnyTensor": from ..ops import transpose diff --git a/sharktank/tests/models/vae/vae_test.py b/sharktank/tests/models/vae/vae_test.py new file mode 100644 index 000000000..99454f6cf --- /dev/null +++ b/sharktank/tests/models/vae/vae_test.py @@ -0,0 +1,202 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +import sys + +import torch + +from iree.turbine import aot + +from sharktank.types import Dataset +from sharktank.models.vae.model import VaeDecoderModel +from sharktank.models.vae.tools.diffuser_ref import run_torch_vae +from sharktank.models.vae.tools.run_vae import export_vae +from sharktank.models.vae.tools.sample_data import get_random_inputs + +from sharktank.models.punet.tools.sample_data import load_inputs, save_outputs +from sharktank.tools.import_hf_dataset import import_hf_dataset +from iree.turbine.aot import FxProgramsBuilder, export, decompositions +from sharktank.utils.hf_datasets import get_dataset +import unittest +import pytest +from huggingface_hub import hf_hub_download +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +import iree.compiler +from collections import OrderedDict + +with_vae_data = pytest.mark.skipif("not config.getoption('with_vae_data')") + + +@with_vae_data +class VaeSDXLDecoderTest(unittest.TestCase): + def setUp(self): + hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0" + hf_hub_download( + repo_id=hf_model_id, + local_dir="sdxl_vae", + local_dir_use_symlinks=False, + revision="main", + filename="vae/config.json", + ) + hf_hub_download( + repo_id=hf_model_id, + local_dir="sdxl_vae", + local_dir_use_symlinks=False, + revision="main", + filename="vae/diffusion_pytorch_model.safetensors", + ) + hf_hub_download( + repo_id="amd-shark/sdxl-quant-models", + local_dir="sdxl_vae", + local_dir_use_symlinks=False, + revision="main", + filename="vae/vae.safetensors", + ) + torch.manual_seed(12345) + f32_dataset = import_hf_dataset( + "sdxl_vae/vae/config.json", + ["sdxl_vae/vae/diffusion_pytorch_model.safetensors"], + ) + f32_dataset.save("sdxl_vae/vae_f32.irpa", io_report_callback=print) + f16_dataset = import_hf_dataset( + "sdxl_vae/vae/config.json", ["sdxl_vae/vae/vae.safetensors"] + ) + f16_dataset.save("sdxl_vae/vae_f16.irpa", io_report_callback=print) + + def 
testCompareF32EagerVsHuggingface(self): + dtype = getattr(torch, "float32") + inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) + ref_results = run_torch_vae("sdxl_vae", inputs) + + ds = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + model = VaeDecoderModel.from_dataset(ds).to(device="cpu") + + results = model.forward(inputs) + + torch.testing.assert_close(ref_results, results) + + @pytest.mark.skip(reason="running fp16 on cpu is extremely slow") + def testCompareF16EagerVsHuggingface(self): + dtype = getattr(torch, "float32") + inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) + ref_results = run_torch_vae("sdxl_vae", inputs) + + ds = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") + model = VaeDecoderModel.from_dataset(ds).to(device="cpu") + + results = model.forward(inputs.to(torch.float16)) + + torch.testing.assert_close(ref_results, results) + + def testVaeIreeVsHuggingFace(self): + dtype = getattr(torch, "float32") + inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) + ref_results = run_torch_vae("sdxl_vae", inputs) + + ds_f16 = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") + ds_f32 = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + + model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu") + model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") + + # TODO: Decomposing attention due to https://github.com/iree-org/iree/issues/19286, remove once issue is resolved + module_f16 = export_vae(model_f16, inputs.to(torch.float16), True) + module_f32 = export_vae(model_f32, inputs, True) + + module_f16.save_mlir("sdxl_vae/vae_f16.mlir") + module_f32.save_mlir("sdxl_vae/vae_f32.mlir") + extra_args = [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx942", + "--iree-opt-const-eval=false", + "--iree-opt-strip-assertions=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-hip-waves-per-eu=2", + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-execution-model=async-external", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", + ] + + iree.compiler.compile_file( + "sdxl_vae/vae_f16.mlir", + output_file="sdxl_vae/vae_f16.vmfb", + extra_args=extra_args, + ) + iree.compiler.compile_file( + "sdxl_vae/vae_f32.mlir", + output_file="sdxl_vae/vae_f32.vmfb", + extra_args=extra_args, + ) + + iree_devices = get_iree_devices(driver="hip", device_count=1) + + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path="sdxl_vae/vae_f16.vmfb", + devices=iree_devices, + parameters_path="sdxl_vae/vae_f16.irpa", + ) + + input_args = OrderedDict([("inputs", inputs.to(torch.float16))]) + iree_args = flatten_for_iree_signature(input_args) + + iree_args = prepare_iree_module_function_args( + args=iree_args, devices=iree_devices + ) + iree_result = run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name="forward", + )[0].to_host() + # TODO: Verify these numerics are good or if tolerances are too loose + # TODO: Upload IR on passing tests to keep https://github.com/iree-org/iree/blob/main/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py at latest + torch.testing.assert_close( + ref_results.to(torch.float16), + torch.from_numpy(iree_result), + atol=5e-2, 
+ rtol=4e-1, + ) + + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path="sdxl_vae/vae_f32.vmfb", + devices=iree_devices, + parameters_path="sdxl_vae/vae_f32.irpa", + ) + + input_args = OrderedDict([("inputs", inputs)]) + iree_args = flatten_for_iree_signature(input_args) + + iree_args = prepare_iree_module_function_args( + args=iree_args, devices=iree_devices + ) + iree_result = run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name="forward", + )[0].to_host() + # TODO: Upload IR on passing tests + torch.testing.assert_close( + ref_results, torch.from_numpy(iree_result), atol=3e-5, rtol=6e-6 + ) + + +if __name__ == "__main__": + unittest.main() From 7862ff8aef1cbc0ab5ceea48afebabef00402c09 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 19 Dec 2024 15:15:36 -0800 Subject: [PATCH 32/39] Iterate on llama user guide [2]. (#718) Progress on https://github.com/nod-ai/shark-ai/issues/691. * Generalize guide to any llama model on any accelerator, mentioning specifics of platform/model support where it matters * Restructure the introduction section with more context and an overview of the rest of the guide (more still to do here, explaining this tech stack) * Add prerequisites section, modeled after the [user guide](https://github.com/nod-ai/shark-ai/blob/main/docs/user_guide.md) * Add more "why" explanations for many steps (more still to do here) * Start trimming environment variable instructions --- ..._llama8b_mi300x.md => llama_end_to_end.md} | 163 ++++++++++-------- .../shortfin_with_sglang_frontend_language.md | 2 +- 2 files changed, 92 insertions(+), 73 deletions(-) rename docs/shortfin/llm/user/{e2e_llama8b_mi300x.md => llama_end_to_end.md} (51%) diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/llama_end_to_end.md similarity index 51% rename from docs/shortfin/llm/user/e2e_llama8b_mi300x.md rename to docs/shortfin/llm/user/llama_end_to_end.md index 36ea817f2..a74851407 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/llama_end_to_end.md @@ -1,39 +1,53 @@ -# LLama 8b GPU instructions on MI300X +# Llama end to end serving instructions -## Setup +## Introduction -We will use an example with `llama_8b_f16` in order to describe the -process of exporting a model for use in the shortfin llm server with an -MI300 GPU. +This guide demonstrates how to serve the +[Llama family](https://www.llama.com/) of Large Language Models (LLMs) using +shark-ai. -### Pre-Requisites +* By the end of this guide you will have a server running locally and you will + be able to send HTTP requests containing chat prompts and receive chat + responses back. -- Python >= 3.11 is recommended for this flow - - You can check out [pyenv](https://github.com/pyenv/pyenv) - as a good tool to be able to manage multiple versions of python - on the same system. +* We will demonstrate the development flow using a version of the + [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) + model, quantized to fp16. Other models in the + [Llama 3.1 family](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) + are supported as well. -### Create virtual environment +Overview: -To start, create a new virtual environment: +1. Setup, installing dependencies and configuring the environment +2. Download model files then compile the model for our accelerator(s) of choice +3. 
Start a server using the compiled model files +4. Send chat requests to the server and receive chat responses back -```bash -python -m venv --prompt shark-ai .venv -source .venv/bin/activate -``` +## 1. Setup + +### Pre-requisites -## Install stable shark-ai packages +- An installed + [AMD Instinct™ MI300X Series Accelerator](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html) + - Other accelerators should work too, but shark-ai is currently most + optimized on MI300X +- Compatible versions of Linux and ROCm (see the [ROCm compatability matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html)) +- Python >= 3.11 -First install a torch version that fulfills your needs: +### Create virtual environment + +To start, create a new +[virtual environment](https://docs.python.org/3/library/venv.html): ```bash -# Fast installation of torch with just CPU support. -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +python -m venv --prompt shark-ai .venv +source .venv/bin/activate ``` -For other options, see https://pytorch.org/get-started/locally/. +### Install Python packages -Next install shark-ai: +Install `shark-ai`, which includes the `sharktank` model development toolkit +and the `shortfin` serving framework: ```bash pip install shark-ai[apps] @@ -43,17 +57,28 @@ pip install shark-ai[apps] > To switch from the stable release channel to the nightly release channel, > see [`nightly_releases.md`](../../../nightly_releases.md). -### Define a directory for export files +The `sharktank` project contains implementations of popular LLMs optimized for +ahead of time compilation and serving via `shortfin`. These implementations are +built using PyTorch, so install a `torch` version that fulfills your needs by +following either https://pytorch.org/get-started/locally/ or our recommendation: -Create a new directory for us to export files like -`model.mlir`, `model.vmfb`, etc. +```bash +# Fast installation of torch with just CPU support. +pip install torch --index-url https://download.pytorch.org/whl/cpu +``` + +### Prepare a working directory + +Create a new directory for model files and compilation artifacts: ```bash -mkdir $PWD/export export EXPORT_DIR=$PWD/export +mkdir -p $EXPORT_DIR ``` -### Download llama3_8b_fp16.gguf +## 2. Download and compile the model + +### Download `llama3_8b_fp16.gguf` We will use the `hf_datasets` module in `sharktank` to download a LLama3.1 8b f16 model. @@ -64,10 +89,10 @@ python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR ### Define environment variables -Define the following environment variables to make running -this example a bit easier: +We'll first define some environment variables that are shared between the +following steps. -#### Model/Tokenizer vars +#### Model/tokenizer variables This example uses the `llama8b_f16.gguf` and `tokenizer.json` files that were downloaded in the previous step. 
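Before defining the remaining variables and exporting, it can be worth confirming that the download produced the files the rest of the guide expects. Here is a quick check using only the Python standard library; the expected filenames are assumptions mirroring the paths used elsewhere in this guide.

```python
# Sanity-check the downloaded model files before exporting; standard library only.
# The expected filenames are assumptions mirroring the paths used in this guide.
import os
from pathlib import Path

export_dir = Path(os.environ.get("EXPORT_DIR", "export"))
expected = [
    "meta-llama-3.1-8b-instruct.f16.gguf",
    "tokenizer.json",
    "tokenizer_config.json",
]

missing = [name for name in expected if not (export_dir / name).exists()]
if missing:
    raise SystemExit(f"Missing files in {export_dir}: {missing}")
print(f"Found all expected model files in {export_dir}")
```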
@@ -77,52 +102,41 @@ export MODEL_PARAMS_PATH=$EXPORT_DIR/meta-llama-3.1-8b-instruct.f16.gguf export TOKENIZER_PATH=$EXPORT_DIR/tokenizer.json ``` -#### General env vars +#### General environment variables -The following env vars can be copy + pasted directly: +These variables configure the model export and compilation process: ```bash -# Path to export model.mlir file export MLIR_PATH=$EXPORT_DIR/model.mlir -# Path to export config.json file export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json -# Path to export model.vmfb file export VMFB_PATH=$EXPORT_DIR/model.vmfb -# Batch size for kvcache -export BS=1,4 +export EXPORT_BATCH_SIZES=1,4 # NOTE: This is temporary, until multi-device is fixed export ROCR_VISIBLE_DEVICES=1 ``` -## Export to MLIR +### Export to MLIR using sharktank -We will now use the `sharktank.examples.export_paged_llm_v1` script -to export our model to `.mlir` format. +We will now use the +[`sharktank.examples.export_paged_llm_v1`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/examples/export_paged_llm_v1.py) +script to export an optimized implementation of the LLM from PyTorch to the +`.mlir` format that our compiler can work with: ```bash python -m sharktank.examples.export_paged_llm_v1 \ --gguf-file=$MODEL_PARAMS_PATH \ --output-mlir=$MLIR_PATH \ --output-config=$OUTPUT_CONFIG_PATH \ - --bs=$BS + --bs=$EXPORT_BATCH_SIZES ``` -## Compiling to `.vmfb` - -Now that we have generated a `model.mlir` file, -we can compile it to `.vmfb` format, which is required for running -the `shortfin` LLM server. +### Compile using IREE to a `.vmfb` file -We will use the +Now that we have generated a `model.mlir` file, we can compile it to the `.vmfb` +format, which is required for running the `shortfin` LLM server. We will use the [iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile) tool for compiling our model. -### Compile for MI300 - -**NOTE: This command is specific to MI300 GPUs. -For other `--iree-hip-target` GPU options, -look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)** - ```bash iree-compile $MLIR_PATH \ --iree-hal-target-backends=rocm \ @@ -130,26 +144,31 @@ iree-compile $MLIR_PATH \ -o $VMFB_PATH ``` -## Running the `shortfin` LLM server +> [!NOTE] +> The `--iree-hip-target=gfx942` option will generate code for MI300 series +> GPUs. To compile for other targets, see +> [the options here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program). -We should now have all of the files that we need to run the shortfin LLM server. +### Check exported files -Verify that you have the following in your specified directory ($EXPORT_DIR): +We should now have all of the files that we need to run the shortfin LLM server: ```bash -ls $EXPORT_DIR +ls -1A $EXPORT_DIR ``` -- config.json -- meta-llama-3.1-8b-instruct.f16.gguf -- model.mlir -- model.vmfb -- tokenizer_config.json -- tokenizer.json +Expected output: -### Launch server +``` +config.json +meta-llama-3.1-8b-instruct.f16.gguf +model.mlir +model.vmfb +tokenizer_config.json +tokenizer.json +``` -#### Run the shortfin server +## 3. Run the `shortfin` LLM server Now that we are finished with setup, we can start the Shortfin LLM Server. @@ -184,18 +203,16 @@ when you see the following logs outputted to terminal: cat shortfin_llm_server.log ``` -#### Expected output +Expected output: -```text +``` [2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete. 
[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` -## Test the server - -We can now test our LLM server. +## 4. Test the server -First let's confirm that it is running: +We can now test our LLM server. First let's confirm that it is running: ```bash curl -i http://localhost:8000/health @@ -217,6 +234,8 @@ curl http://localhost:8000/generate \ }' ``` +The response should come back as `Washington, D.C.!`. + ### Send requests from Python You can also send HTTP requests from Python like so: @@ -242,7 +261,7 @@ generation_request() ## Cleanup -When done, you can stop the shortfin_llm_server by killing the process: +When done, you can stop the `shortfin_llm_server` by killing the process: ```bash kill -9 $shortfin_process diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md index b63861a56..1292eba66 100644 --- a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md +++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md @@ -35,7 +35,7 @@ For this tutorial, you will need to meet the following prerequisites: ## Install/Start `shortfin` LLM server -Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md) +Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/llama_end_to_end.md) to export a model with `sharktank` and start a `shortfin` LLM server with that model. From fc9576b4b3e53e01462258bc5667b79cc560c591 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 19 Dec 2024 17:34:34 -0800 Subject: [PATCH 33/39] [llama] Added the fused rotary embedding kernel (#719) Reworked rotary embedding application to be performed via a custom kernel. This includes dropping `static_table` for the sake of maintenance (it was largely unused). It includes a simple numerical test however under the hood no numerical change should occur. Existing baseline vs hugging face remained unchanged. --- .../sharktank/examples/export_paged_llm_v1.py | 9 +- sharktank/sharktank/kernels/__init__.py | 1 + sharktank/sharktank/kernels/rotary.py | 70 +++++++++++++++ .../kernels/templates/rotary_embedding.mlir | 63 +++++++++++++ .../layers/paged_llama_attention_block.py | 2 +- .../sharktank/layers/rotary_embedding.py | 88 +++++++------------ sharktank/sharktank/models/llama/llama.py | 4 +- sharktank/tests/kernels/rotary.py | 31 +++++++ 8 files changed, 199 insertions(+), 69 deletions(-) create mode 100644 sharktank/sharktank/kernels/rotary.py create mode 100644 sharktank/sharktank/kernels/templates/rotary_embedding.mlir create mode 100644 sharktank/tests/kernels/rotary.py diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 312a53d33..f93e58e61 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -79,7 +79,6 @@ def main(): hp, tensor_parallelism_size=tensor_parallelism_size, use_hf=False, - static_tables=False, # Rely on the compiler for hoisting tables. 
kv_cache_type="direct" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, block_seq_stride=args.block_seq_stride, @@ -219,22 +218,16 @@ def _(model, tokens, seq_lens, seq_block_ids, cs): else: cache_tensors = cs - sl = tokens.shape[1] - input_mask = model.input_mask(seq_lens, sl) - attention_mask = model.attention_mask(input_mask) - if llama_config.tensor_parallelism_size != 1: shard_count = llama_config.tensor_parallelism_size tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) - cache_tensors = repack_cache(cs, cache_shard_dim) logits = model.prefill( tokens, - attention_mask=attention_mask, + attention_mask=None, # We rely on causal attention seq_block_ids=seq_block_ids, cache_state=cache_tensors, ) diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index 445f44852..1b84f0bee 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -10,6 +10,7 @@ from .mmt_block_scaled_offset_q4 import * from .mmt_block_scaled_q8 import * from .mmt_super_block_scaled_offset_q4 import * +from .rotary import * from .batch_matmul_transpose_b import * from .conv_2d_nchw_fchw import * from .pooling_nchw_sum import * diff --git a/sharktank/sharktank/kernels/rotary.py b/sharktank/sharktank/kernels/rotary.py new file mode 100644 index 000000000..196fc32c2 --- /dev/null +++ b/sharktank/sharktank/kernels/rotary.py @@ -0,0 +1,70 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.kernels.base import * + +__all__ = [ + "apply_rotary_embedding", +] + + +@CustomOp.register(library=LIBRARY) +class apply_rotary_embedding(CustomOp): + + signature = "apply_rotary_embedding(Tensor input, Tensor table) -> (Tensor)" + + def select(self, ksel: KernelSelection): + inputs_desc = ksel.arg_tensor(0) + table_desc = ksel.arg_tensor(1) + out_desc = ksel.return_new_tensor( + inputs_desc.t.shape, dtype=inputs_desc.t.dtype + ) + specialize_all_known_dims(inputs_desc) + specialize_all_known_dims(table_desc) + specialize_all_known_dims(out_desc) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + + input = kb.arg_value(0) + table = kb.arg_value(1) + + input_tensor_type = RankedTensorType(input.type) + table_tensor_type = RankedTensorType(table.type) + + input_asm_type, input_ident, input_dtype = unpack_tensor_type(input.type) + table_asm_type, table_ident, table_dtype = unpack_tensor_type(table.type) + + assert input_dtype == table_dtype + + # Generate specialization signature and types. + bs = input.type.shape[0] + sl = input.type.shape[1] + sl = "D" if sl < 0 else sl + heads = input.type.shape[2] + dims = input.type.shape[3] + + template_file = "rotary_embedding.mlir" + target_function_name = ( + f"sharktank_rotary_embedding_{bs}_{sl}_{heads}_{dims}_{input_dtype}" + ) + + # Template params. 
+ input_tensor_type = input_asm_type + table_tensor_type = table_asm_type + + target_function = inline_template_function( + kb, + template_file, + target_function_name, + input_tensor_type=input_tensor_type, + table_tensor_type=table_tensor_type, + bs=bs, + sl=sl, + heads=heads, + dims=dims, + dtype=str(input_dtype), + ) + kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/rotary_embedding.mlir b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir new file mode 100644 index 000000000..adec6805b --- /dev/null +++ b/sharktank/sharktank/kernels/templates/rotary_embedding.mlir @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +!input_tensor_type = {{input_tensor_type}} +!table_tensor_type = {{table_tensor_type}} + +module { + +util.func private @sharktank_rotary_embedding_{{bs}}_{{sl}}_{{heads}}_{{dims}}_{{dtype}}(%input: !input_tensor_type, %table: !table_tensor_type) -> !input_tensor_type { + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + + %d0 = tensor.dim %input, %c0 : !input_tensor_type + %d1 = tensor.dim %input, %c1 : !input_tensor_type + %d2 = tensor.dim %input, %c2 : !input_tensor_type + %d3 = tensor.dim %input, %c3 : !input_tensor_type + + %empty_dyn = tensor.empty(%d0, %d1, %d2, %d3) : tensor + %empty = tensor.cast %empty_dyn : tensor to {{input_tensor_type}} + + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%table : !table_tensor_type ) + outs(%empty : !input_tensor_type) { + ^bb0(%b0 : {{dtype}} , %b1 : {{dtype}}): + %0 = linalg.index 0 : index + %1 = linalg.index 1 : index + %2 = linalg.index 2 : index + %3 = linalg.index 3 : index + %div = arith.divui %3, %c2 : index + %mod = arith.remui %3, %c2 : index + %a_cosb = math.cos %b0 : {{dtype}} + %a_sinb = math.sin %b0 : {{dtype}} + %real_index = arith.muli %div, %c2 : index + %imag_index = arith.addi %real_index, %c1 : index + %real = tensor.extract %input[%0, %1, %2, %real_index] : !input_tensor_type + %imag = tensor.extract %input[%0, %1, %2, %imag_index] : !input_tensor_type + %cmp = arith.cmpi eq, %mod, %c0 : index + %real_t0 = arith.mulf %real, %a_cosb : {{dtype}} + %real_t1 = arith.mulf %imag, %a_sinb : {{dtype}} + %real_t2 = arith.subf %real_t0, %real_t1 : {{dtype}} + %imag_t0 = arith.mulf %imag, %a_cosb : {{dtype}} + %imag_t1 = arith.mulf %real, %a_sinb : {{dtype}} + %imag_t2 = arith.addf %imag_t0, %imag_t1 : {{dtype}} + %val = arith.select %cmp, %real_t2, %imag_t2 : {{dtype}} + linalg.yield %val : {{dtype}} + } -> !input_tensor_type + + util.return %result : !input_tensor_type +} + +} diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 6bd33c93f..d74e2a92d 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: k=keys, # [bs, ..., sl, dim] v=values, # [bs, ..., sl, dim] a=attention_mask, # [bs, ..., sl, sl] - is_causal=False, # assumes causal masking when 
true + is_causal=attention_mask is None, # assumes causal masking when true scale=None, # defaults to 1/sqrt(dim) ) diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 99ecf5057..623c02ea6 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -11,6 +11,7 @@ from .base import BaseLayer from .. import ops +from .. import kernels from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor @@ -25,7 +26,6 @@ def __init__( rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, - static_tables: bool = False, use_table: bool = True, tensor_parallelism_size: int = 1, ): @@ -34,26 +34,14 @@ def __init__( self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf - self.static_tables = static_tables self.use_table = use_table self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 self.tensor_parallelism_size = tensor_parallelism_size - if static_tables: - ops.module_register_buffer( - self, "static_rotary_embed_table", self._create_rotary_embed_table() - ) - else: - self.static_rotary_embed_table = None @property def rotary_embed_table(self): - if self.use_table: - if self.static_tables: - return self.static_rotary_embed_table - return self._create_rotary_embed_table() - - return None + return self._create_rotary_embed_table() def forward( self, @@ -61,33 +49,29 @@ def forward( xt: Union[torch.Tensor, SplitPrimitiveTensor], start_index: int, ): - if isinstance(xt, SplitPrimitiveTensor): - rotary_shards = [None] * xt.shard_count - if self.rotary_embed_table is not None: - assert ( - isinstance(self.rotary_embed_table, ReplicatedTensor) - and xt.shard_count == self.rotary_embed_table.shard_count - ) - rotary_shards = [ - unbox_tensor(shard) for shard in self.rotary_embed_table.shards - ] - - xt_shards = [ - self.forward_unsharded( - xt=unbox_tensor(xt_shard), - start_index=start_index, - rotary_embed_table=rotary_shard, - ) - for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) - ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) - return xt - else: + table = self.rotary_embed_table + if not isinstance(xt, SplitPrimitiveTensor): return self.forward_unsharded( xt=xt, start_index=start_index, - rotary_embed_table=self.rotary_embed_table, + rotary_embed_table=table, + ) + + assert ( + isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count + ) + rotary_shards = [unbox_tensor(shard) for shard in table.shards] + + xt_shards = [ + self.forward_unsharded( + xt=unbox_tensor(xt_shard), + start_index=start_index, + rotary_embed_table=rotary_shard, ) + for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) + ] + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt def _create_interleaved_tensor(_, dim): """Creates a tensor which indexes an tensor such that @@ -143,18 +127,17 @@ def forward_unsharded( # Offset the table based on starting position. 
if self.use_table: freqs_cis = rotary_embed_table[start_index : start_index + sl, :] - freqs_cis = freqs_cis[None, 0:sl, None, :] + freqs_cis = freqs_cis[0:sl, :] else: freqs_cis = torch.arange(sl, device=xt.device) + start_index - freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :] + freqs_cis = self._compute_rotary_embed_table(freqs_cis) assert ( - freqs_cis.shape[1] >= sl + freqs_cis.shape[0] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" - xt_ = ops.view_as_complex(xt_) - xt_ = xt_ * freqs_cis - xt_out = ops.view_as_real(xt_) + freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1)) + xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis) if self.use_hf: xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] @@ -181,7 +164,7 @@ def compute_batch_mask( self.trace_tensor("rope.positions_seq", positions_seq) if self.use_table: - freqs_cis = self.rotary_embed_table[positions_seq] + freqs_cis = self.rotary_embed_table[positions_seq.flatten()] else: shape = positions_seq.shape if isinstance(positions_seq, ReplicatedTensor): @@ -192,11 +175,8 @@ def compute_batch_mask( freqs_cis = ReplicatedTensor(ts=ts) else: freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) - freqs_cis = freqs_cis.unflatten(0, shape) - # Unsqueeze a unit dim for attention heads. - broadcast_freqs_cis = freqs_cis.unsqueeze(2) - return broadcast_freqs_cis + return freqs_cis.unsqueeze(1) def apply_batched_mask( self, @@ -232,9 +212,7 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): if self.use_hf: xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])] - xt_ = ops.view_as_complex(xt) - xt_ = xt_ * mask - xt_out = ops.view_as_real(xt_) + xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask) if self.use_hf: xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] @@ -244,14 +222,10 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count freqs = 1.0 / ( - self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0) ) freqs = torch.outer(t, freqs).float() - - cos = torch.cos(freqs) - sin = torch.sin(freqs) - complex = torch.complex(cos, sin) - return complex + return freqs def _create_rotary_embed_table(self): t = torch.arange(self.max_seqlen, device=self.device) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 0a9a6f1c3..6fef6704e 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): super().__init__( theta, context_length=config.hp.context_length, - static_tables=config.static_tables, device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, @@ -92,7 +91,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): max_seqlen=hp.context_length, device=self.device, use_hf=self.use_hf, - static_tables=config.static_tables, tensor_parallelism_size=config.tensor_parallelism_size, ), ) @@ -126,7 +124,7 @@ def prefill( tokens: Union[torch.Tensor, ReplicatedTensor], *, # [1, 1, batch_seq_len, batch_seq_len] - attention_mask: Union[torch.Tensor, ReplicatedTensor], + attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]], # 
[bs, batch_seq_len // block_seq_stride] seq_block_ids: Union[torch.Tensor, ReplicatedTensor], cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], diff --git a/sharktank/tests/kernels/rotary.py b/sharktank/tests/kernels/rotary.py new file mode 100644 index 000000000..6c3d032a3 --- /dev/null +++ b/sharktank/tests/kernels/rotary.py @@ -0,0 +1,31 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import torch +import unittest + +from sharktank import kernels +from sharktank import ops + + +class rotary_test(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def test_rotary(self): + dtype = torch.float32 + a = torch.rand([1, 128, 1, 64], dtype=dtype) + rot = torch.rand([128, 32], dtype=dtype) + res_b = ops.view_as_real(torch.complex(rot, rot)) + ref_b = torch.complex(torch.cos(rot), torch.sin(rot)) + + result = kernels.apply_rotary_embedding(a, res_b) + ref = ops.view_as_real(ops.view_as_complex(a) * ref_b[None, :, None, :]) + torch.testing.assert_close(result, ref) From 277618662a515f80f537d9b058b8ff1f15ca4ec0 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Fri, 20 Dec 2024 17:08:25 -0500 Subject: [PATCH 34/39] Fix gibberish token problem for prefill also (#723) missed a line in #665 --- shortfin/python/shortfin_apps/llm/components/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 9ef8b5c4d..c8d6c51a8 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -383,7 +383,7 @@ async def run(self): if self.phase == InferencePhase.PREFILL: seq_lens_host = seq_lens.for_transfer() with seq_lens_host.map(discard=True) as m: - m.fill(0) + m.fill(1) m.items = [len(req.input_token_ids) for req in self.exec_requests] seq_lens_host.copy_to(seq_lens) From f5e9cb4ef3fb37c31b438ea6a88b9c8179b7e9e7 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:10:23 -0600 Subject: [PATCH 35/39] SGLang doc user flow updates (#703) - Expand examples in user docs to cover all supported features - Default docs to targeting cluster application - General cleanup and removal of unnecessary text - Add k8s instructions for shortfin deployment --------- Co-authored-by: saienduri Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> Co-authored-by: Scott Todd --- .pre-commit-config.yaml | 1 + .../{llama_end_to_end.md => llama_serving.md} | 29 ++ .../llm/user/llama_serving_on_kubernetes.md | 44 +++ .../shortfin_with_sglang_frontend_language.md | 278 +++++++++--------- .../llm/k8s/llama-app-deployment.yaml | 59 ++++ 5 files changed, 276 insertions(+), 135 deletions(-) rename docs/shortfin/llm/user/{llama_end_to_end.md => llama_serving.md} (67%) create mode 100644 docs/shortfin/llm/user/llama_serving_on_kubernetes.md create mode 100644 shortfin/deployment/shortfin_apps/llm/k8s/llama-app-deployment.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f2be5cf6..e2e8d797f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,7 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml + args: 
['--allow-multiple-documents'] - id: check-added-large-files - repo: https://github.com/psf/black rev: 22.10.0 diff --git a/docs/shortfin/llm/user/llama_end_to_end.md b/docs/shortfin/llm/user/llama_serving.md similarity index 67% rename from docs/shortfin/llm/user/llama_end_to_end.md rename to docs/shortfin/llm/user/llama_serving.md index a74851407..cc2c959b4 100644 --- a/docs/shortfin/llm/user/llama_end_to_end.md +++ b/docs/shortfin/llm/user/llama_serving.md @@ -272,3 +272,32 @@ If you want to find the process again: ```bash ps -f | grep shortfin ``` + +## Server Options + +To run the server with different options, you can use the +following command to see the available flags: + +```bash +python -m shortfin_apps.llm.server --help +``` + +### Server Options + +A full list of options can be found below: + +| Argument | Description | +| ----------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--host HOST` | Specify the host to bind the server. | +| `--port PORT` | Specify the port to bind the server. | +| `--root-path ROOT_PATH` | Root path to use for installing behind a path-based proxy. | +| `--timeout-keep-alive TIMEOUT_KEEP_ALIVE` | Keep-alive timeout duration. | +| `--tokenizer_json TOKENIZER_JSON` | Path to a `tokenizer.json` file. | +| `--tokenizer_config_json TOKENIZER_CONFIG_JSON` | Path to a `tokenizer_config.json` file. | +| `--model_config MODEL_CONFIG` | Path to the model config file. | +| `--vmfb VMFB` | Model [VMFB](https://iree.dev/developers/general/developer-tips/#inspecting-vmfb-files) to load. | +| `--parameters [FILE ...]` | Parameter archives to load (supports: `gguf`, `irpa`, `safetensors`). | +| `--device {local-task,hip,amdgpu}` | Device to serve on (e.g., `local-task`, `hip`). Same options as [iree-run-module --list_drivers](https://iree.dev/guides/deployment-configurations/gpu-rocm/#get-the-iree-runtime). | +| `--device_ids [DEVICE_IDS ...]` | Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a device ID like `amdgpu:0:0@0`. | +| `--isolation {none,per_fiber,per_call}` | Concurrency control: How to isolate programs. | +| `--amdgpu_async_allocations` | Enable asynchronous allocations for AMD GPU device contexts. | diff --git a/docs/shortfin/llm/user/llama_serving_on_kubernetes.md b/docs/shortfin/llm/user/llama_serving_on_kubernetes.md new file mode 100644 index 000000000..f573bd8ae --- /dev/null +++ b/docs/shortfin/llm/user/llama_serving_on_kubernetes.md @@ -0,0 +1,44 @@ +# Llama 8b GPU instructions on Kubernetes + +## Setup + +We will use an example with `llama_8b_f16` in order to describe the +process of exporting a model and deploying four instances of a shortfin llm server +behind a load balancer on MI300X GPU. + +### Pre-Requisites + +- Kubernetes cluster available to use +- kubectl installed on system and configured for cluster of interest + - To install kubectl, please check out [kubectl install](https://kubernetes.io/docs/tasks/tools/#kubectl) + and make sure to set the `KUBECONFIG` environment variable to point to your kube config file to authorize + connection to the cluster. + +### Deploy shortfin llama app service + +To generate the artifacts required for this k8s deployment, please follow [llama_serving.md](./llama_serving.md) until you have have all of the files that we need to run the shortfin LLM server. 
+Please upload your artifacts to a storage option that you can pull from in your k8s cluster (NFS, S3, CSP). +Save [llama-app-deployment.yaml](../../../../shortfin/deployment/shortfin_apps/llm/k8s/llama-app-deployment.yaml) locally and edit it to include your artifacts you just stored and change flags to intended configuration. + +To deploy llama app: + +``` +kubectl apply -f llama-app-deployment.yaml +``` + +To retrieve external IP for targetting the llama app load balancer: + +``` +kubectl get service shark-llama-app-service +``` + +Now, you can use the external IP for sglang integration or just sending text generation requests. + +### Delete shortfin llama app service + +After done using, make sure to delete: + +``` +kubectl delete deployment shark-llama-app-deployment +kubectl delete service shark-llama-app-service +``` diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md index 1292eba66..3812b5277 100644 --- a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md +++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md @@ -4,15 +4,15 @@ This doc includes basic steps for hooking up sglang with a running Shortfin serv ## Current Support Status -| Feature | Description | Enabled | Reference | -| ----------- | ----------- | ---------- | ------------ | -| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) | -| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) | -| `run_batch` | Run batch of disjoint requests with continous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) | -| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) | -| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) | -| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) | -| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) | +| Feature | Description | Enabled | Reference | +| ----------- | --------------------------------------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------ | +| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) | +| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) | +| `run_batch` | Run batch of disjoint requests with continous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) | +| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) | +| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices 
Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) | +| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) | +| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) | ## Prerequisites @@ -24,20 +24,22 @@ For this tutorial, you will need to meet the following prerequisites: - You can check out [pyenv](https://github.com/pyenv/pyenv) as a good tool to be able to manage multiple versions of python on the same system. -- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server) + +### Shortfin LLM Server + +- A running `shortfin` LLM server. Directions on launching the llm server on one system can be found in [Llama end to end serving instructions](./llama_serving.md) and for launching +on a kubernetes cluster, see [Llama 8b GPU instructions on Kubernetes](./llama_serving_on_kubernetes.md) - We will use the shortfin server as the `backend` to generate completions from SGLang's `frontend language`. In this tutorial, you can think of `sglang` as the client and `shortfin` as the server. -### Hardware - -- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html) - -## Install/Start `shortfin` LLM server +After the `shortfin` LLM Server has started, we must obtain the base_url. +We will store this in our environment in order to send request to `shortfin` + through the `sglang` client examples below. -Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/llama_end_to_end.md) -to export a model with `sharktank` and start a `shortfin` LLM server -with that model. +```bash +export SHORTFIN_BASE_URL="SHORTFIN_BASE_URL" # example: http://localhost:8000 +``` ## Install sglang @@ -48,6 +50,8 @@ We can use pip to install it in the same virtual environment that we used to start our Shortfin LLM Server. ```bash +python -m venv --prompt shark-ai .venv +source .venv/bin/activate pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" ``` @@ -56,8 +60,23 @@ pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" You can verify the installation/setup through the following examples: - [Multi-Turn Q&A Example](#multi-turn-qa-example) +- [Streaming Example](#streaming-example) - [Fork Example](#fork-example) -- [Benchmark Shortfin](#bench-mark-shortfin-w-sglang-bench_serving-script) +- [Multi-Turn Q&A Batching Example](#multi-turn-qa-batch-example) + +In these examples, we will set our `max_tokens` to 50 when generating completions. +This details how many tokens we want to generate for each completion. + +We can modify the arguments passed to `sgl.gen` to alter the outputs of our +`shortfin` LLM server. Specifically: + +- `max_tokens` - The maximum number of tokens to generate for completion. + We may obtain longer responses by increasing this value, + and shorter responses by decreasing it. +- `temperature` - We can include a temperature parameter to control the + randomness of the generated completions. A higher value + will result in more randomness, while a lower value will + result in more deterministic completions. 
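+
+For example, a generation call inside an `@sgl.function`-decorated function
+could set both parameters like this (a minimal sketch; the values are arbitrary
+and only illustrate the knobs described above):
+
+```python
+# Ask for a short, more deterministic completion.
+s += sgl.assistant(sgl.gen("answer", max_tokens=50, temperature=0.3))
+```
+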
## Multi-Turn Q&A example @@ -75,20 +94,24 @@ python You can copy and paste the following example into your interpreter: ```python +import os + import sglang as sgl from sglang.lang.chat_template import get_chat_template -backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000", ) # Change base_url if running at different address +SHORTFIN_BASE_URL = os.environ["SHORTFIN_BASE_URL"] + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url=SHORTFIN_BASE_URL) sgl.set_default_backend(backend) @sgl.function def multi_turn_question(s, question_1, question_2): s += sgl.user(question_1) - s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=50)) s += sgl.user(question_2) - s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=50)) state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.") @@ -96,40 +119,64 @@ for m in state.messages(): print(m["role"], m["content"]) ``` -### Shortfin example output +## Streaming Example -You should see an output similar to this: +We can stream our request for a more responsive feel. Let's invoke a `streaming` Q&A from our server: -```text -========== single ========== +```python +import os -user : Name the capital city of the USA -assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia). -user : The Smithsonian is in this location. -assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. -``` +import sglang as sgl +from sglang.lang.chat_template import get_chat_template -## Fork example +SHORTFIN_BASE_URL = os.environ["SHORTFIN_BASE_URL"] -Now that we have sglang installed, we can run an example to show a `fork` -flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url=SHORTFIN_BASE_URL) -### Open python interpreter +sgl.set_default_backend(backend) -```bash -python +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=50)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=50)) + +question_1 = "Name the capital city of the USA." +question_2 = "The Smithsonian is in this location." 
+ +# Run the multi-turn question function with streaming enabled +state = multi_turn_question.run( + question_1=question_1, + question_2=question_2, + stream=True, +) + +# Collect messages from the streamed output +messages = "" + +for chunk in state.text_iter(): + messages += chunk + +print(messages) ``` -### Run example -You can copy and paste the following example into your interpreter: +## Fork example + +We can also send different pieces of the same prompt in parallel using the `fork` +flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): ```python +import os + import sglang as sgl from sglang.lang.chat_template import get_chat_template -backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at different address +SHORTFIN_BASE_URL = os.environ["SHORTFIN_BASE_URL"] + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url=SHORTFIN_BASE_URL) sgl.set_default_backend(backend) @@ -142,7 +189,7 @@ def tip_suggestion(s): forks = s.fork(2) for i, f in enumerate(forks): f += f"Now, expand tip {i+1} into a paragraph:\n" - f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + f += sgl.gen(f"detailed_tip", max_tokens=50, stop="\n\n") s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" s += "In summary" + sgl.gen("summary") @@ -152,103 +199,64 @@ state = tip_suggestion.run() print(state.text()) ``` -### Shortfin example output - -You should see an output similar to this: - -```text -Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise. - -Tip 1:A balanced diet is important for maintaining good health. It should -include a variety of foods from all the major food groups, such as fruits, -vegetables, grains, proteins, and dairy. Eating a balanced diet can help -prevent chronic diseases such as heart disease, diabetes, and obesity. - -Now, expand tip 2 into a paragraph: -Regular exercise is also important for maintaining good health. It can help -improve cardiovascular health, strengthen muscles and bones, and reduce the -risk of chronic diseases. Exercise can also help improve mental health by -reducing stress and anxiety. It is recommended that adults get at least 150 -minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity -exercise per week. - -Now, combine the two paragraphs into a single paragraph: -A balanced diet and regular exercise are both important for maintaining good -health. A balanced diet should include a variety of foods from all the major -food groups, such as fruits, vegetables, grains, proteins, and dairy. -Eating a balanced diet can help prevent chronic diseases such as heart disease, -diabetes, and obesity. Regular exercise is also important for maintaining good -health. It can help improve cardiovascular health, strengthen muscles and bones, -and reduce the risk of chronic diseases. Exercise can also help improve mental -health by reducing stress and anxiety. It is recommended that - -Tip 2:Regular exercise is important for maintaining a healthy body and mind. -It can help improve cardiovascular health, strengthen muscles and bones, -and reduce the risk of chronic diseases such as diabetes and heart disease. -Additionally, exercise has been shown to improve mood, reduce stress, -and increase overall well-being. 
It is recommended that adults engage in -at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of -vigorous-intensity aerobic activity per week, as well as strength training -exercises at least two days per week. - -In summary, a balanced diet and regular exercise are both essential for -maintaining good health. A balanced diet should include a variety of foods from -all the major food groups, while regular exercise can help improve -cardiovascular health, strengthen muscles and bones, reduce the risk of -chronic diseases, and improve mental health. It is recommended that adults -engage in at least 150 minutes of moderate-intensity aerobic activity or -75 minutes of vigorous-intensity aerobic activity per week, -as well as strength training exercises at least two days per week. -``` +## Multi-Turn Q&A Batch Example -## Benchmark shortfin w/ sglang `bench_serving` script +With **Shortfin** + SGLang, we can also easily send requests as a batch. +Let's now invoke a `batched` Q&A flow with the SGLang [Batching](https://sgl-project.github.io/frontend/frontend.html#batching): -We can obtain benchmarking metrics using the `bench_serving` script -provided by SGLang: +```python +import os -**NOTE: Change `--base-url` if running at a different address** +import sglang as sgl +from sglang.lang.chat_template import get_chat_template -```bash -python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1 -``` +SHORTFIN_BASE_URL = os.environ["SHORTFIN_BASE_URL"] + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url=SHORTFIN_BASE_URL) + +# Set the default backend for sglang +sgl.set_default_backend(backend) + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=50)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=50)) + +# Define the questions for the first and second sets +question_1_1 = "Name the capital city of the USA." +question_1_2 = "The Smithsonian is in this location." +question_2_1 = "Name the largest city in the USA." +question_2_2 = "The Empire State Building is in this location." 
+ +# Run the multi-turn question function in batch mode +states = multi_turn_question.run_batch( + [ + { + "question_1": question_1_1, + "question_2": question_1_2, + }, + { + "question_1": question_2_1, + "question_2": question_2_2, + }, + ] +) + +# Extract responses from the states +first_qa = states[0] +second_qa = states[1] + +first_qa_messages = first_qa.messages() +second_qa_messages = second_qa.messages() + +# Print messages from the first QA session +for m in first_qa_messages: + print(m["role"], m["content"]) + +# Print messages from the second QA session +for m in second_qa_messages: + print(m["role"], m["content"]) -There are some more metrics captured, but the most relevant are the following: - -- E2E Latency -- TTFT (Time to First Token) -- TPOT (Time per Output Token) -- ITL (Inter-Token Latency) -- Request Throughput -- Benchmark Duration - -When complete, you should see an output similar to this: - -```text -============ Serving Benchmark Result ============ -Backend: shortfin -Traffic request rate: 1.0 -Successful requests: 10 -Benchmark duration (s): 427.91 -Total input tokens: 1960 -Total generated tokens: 2774 -Total generated tokens (retokenized): 63 -Request throughput (req/s): 0.02 -Input token throughput (tok/s): 4.58 -Output token throughput (tok/s): 6.48 -----------------End-to-End Latency---------------- -Mean E2E Latency (ms): 416268.77 -Median E2E Latency (ms): 417159.14 ----------------Time to First Token---------------- -Mean TTFT (ms): 292404.29 -Median TTFT (ms): 365989.01 -P99 TTFT (ms): 367325.63 ------Time per Output Token (excl. 1st token)------ -Mean TPOT (ms): 1359.41 -Median TPOT (ms): 163.96 -P99 TPOT (ms): 6316.12 ----------------Inter-token Latency---------------- -Mean ITL (ms): 2238.99 -Median ITL (ms): 958.75 -P99 ITL (ms): 2719.50 -================================================== ``` diff --git a/shortfin/deployment/shortfin_apps/llm/k8s/llama-app-deployment.yaml b/shortfin/deployment/shortfin_apps/llm/k8s/llama-app-deployment.yaml new file mode 100644 index 000000000..08a22aa3d --- /dev/null +++ b/shortfin/deployment/shortfin_apps/llm/k8s/llama-app-deployment.yaml @@ -0,0 +1,59 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: shark-llama-app-deployment +spec: + replicas: 4 # number of server instances + selector: + matchLabels: + app: shark-llama-app + template: + metadata: + labels: + app: shark-llama-app + spec: + containers: + - name: shark-llama-app-container + image: rocm/dev-ubuntu-22.04:6.3 + command: ["/bin/bash", "-c"] + # update to artifacts you generated form llama_serving.md (this is an example with the base llama3.1 8b tp1 artifacts) + # change cli flags for instantiation of server to match your intended llama configuration + args: + - | + sudo apt update && + curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash && + sudo apt install git -y && + sudo apt install python3.11 python3.11-dev python3.11-venv -y && + sudo apt-get install wget -y && + python3.11 -m venv shark_venv && source shark_venv/bin/activate && + mkdir shark_artifacts && + wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b/config.json -O shark_artifacts/config.json && + wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b/meta-llama-3.1-8b-instruct.f16.gguf -O shark_artifacts/meta-llama-3.1-8b-instruct.f16.gguf && + wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b/model.vmfb -O shark_artifacts/model.vmfb && + wget 
https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b/tokenizer_config.json -O shark_artifacts/tokenizer_config.json && + wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b/tokenizer.json -O shark_artifacts/tokenizer.json && + pip install --pre shortfin[apps] -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels && + pip install pandas && + python -m shortfin_apps.llm.server --tokenizer_json=shark_artifacts/tokenizer.json --model_config=shark_artifacts/config.json --vmfb=shark_artifacts/model.vmfb --parameters=shark_artifacts/meta-llama-3.1-8b-instruct.f16.gguf --device=hip; + resources: + # change number of gpus required here based on your llama configuration + requests: + amd.com/gpu: 1 + limits: + amd.com/gpu: 1 + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + name: shark-llama-app-service +spec: + selector: + app: shark-llama-app + ports: + - protocol: TCP + port: 80 # external port + targetPort: 8000 # port the container exposes + type: LoadBalancer From ffb0dd2c6be7106e725d49a25b591a3e467913f3 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 31 Dec 2024 13:31:53 -0500 Subject: [PATCH 36/39] Fix FlatSymbolRefAttr (#732) FlatSymbolRefAttr creation was failing when exporting with `iree` at head due to LLVM changes. --- sharktank/sharktank/kernels/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sharktank/sharktank/kernels/base.py b/sharktank/sharktank/kernels/base.py index ce792b525..30791a828 100644 --- a/sharktank/sharktank/kernels/base.py +++ b/sharktank/sharktank/kernels/base.py @@ -102,9 +102,7 @@ def _get_jinja2_env() -> Environment: def call_function(target_function: Operation, *operands: Value) -> Sequence[Value]: - target_symbol = FlatSymbolRefAttr.get( - StringAttr(target_function.attributes["sym_name"]).value_bytes - ) + target_symbol = FlatSymbolRefAttr.get(target_function.attributes["sym_name"].value) ftype = FunctionType(TypeAttr(target_function.attributes["function_type"]).value) operands = [i for i in operands if i is not None] return Operation.create( From a1e632a280ae160f53ffff4ecb0c7ec87b83dddf Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 2 Jan 2025 09:12:50 -0800 Subject: [PATCH 37/39] Add CLI script exporting CLIP Toy model IREE test data (#672) This is required to have an easy way of exporting test data that will be used in IREE to guard against regressions. E.g. ``` python -m sharktank.models.clip.export_toy_text_model_iree_test_data \ --output-dir=clip_toy_text_model ``` Refactor some of the existing tests to reuse the new export logic. 
--- sharktank/sharktank/models/clip/export.py | 25 ++- .../export_toy_text_model_iree_test_data.py | 29 ++++ sharktank/sharktank/models/clip/testing.py | 161 +++++++++++++++++- sharktank/sharktank/utils/io.py | 25 ++- sharktank/tests/models/clip/clip_test.py | 145 ++++++---------- 5 files changed, 280 insertions(+), 105 deletions(-) create mode 100644 sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 3cae3f4c4..95dbdacad 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -11,8 +11,8 @@ CLIPEncoderLayer as HfCLIPEncoderLayer, CLIPEncoder as HfCLIPEncoder, ) -from os import PathLike import torch +from os import PathLike from ...types.theta import Theta, Dataset, torch_module_to_theta from ...layers.configs import ClipTextConfig @@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: return Dataset(properties=model.config.to_properties(), root_theta=model.theta) +def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike): + dataset = clip_text_model_to_dataset(model) + dataset.save(output_path) + + def export_clip_text_model_dataset_from_hugging_face( - model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel], - output_path: Union[str, PathLike], + model_or_name_or_path: Union[PathLike, transformers.CLIPTextModel], + output_path: PathLike, dtype: Optional[torch.dtype] = None, ): if isinstance(model_or_name_or_path, transformers.CLIPTextModel): @@ -99,3 +104,17 @@ def _( output = export(fxb, import_symbolic_shape_expressions=True) output.save_mlir(mlir_output_path) + + +def export_clip_text_model_to_iree( + model: ClipTextModel, + batch_sizes: list[int], + mlir_output_path: PathLike, + parameters_output_path: PathLike, +): + export_clip_text_model_iree_parameters(model, parameters_output_path) + export_clip_text_model_mlir( + model=parameters_output_path, + batch_sizes=batch_sizes, + mlir_output_path=mlir_output_path, + ) diff --git a/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py new file mode 100644 index 000000000..979bc3255 --- /dev/null +++ b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py @@ -0,0 +1,29 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from argparse import ArgumentParser +from typing import Optional +from pathlib import Path + +from .testing import export_clip_toy_text_model_default_iree_test_data + + +def main(args: Optional[list[str]] = None): + parser = ArgumentParser( + description=( + "Export test data for toy-sized CLIP text model." + " This program MLIR, parameters sample input and expected output." + " Exports float32 and bfloat16 model variants." + " The expected output is always in float32 precision." 
+ ) + ) + parser.add_argument("--output-dir", type=str, default=f"clip_toy_text_model") + args = parser.parse_args(args=args) + export_clip_toy_text_model_default_iree_test_data(Path(args.output_dir)) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py index 87634c220..852da8a18 100644 --- a/sharktank/sharktank/models/clip/testing.py +++ b/sharktank/sharktank/models/clip/testing.py @@ -4,14 +4,167 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ...layers.configs.llm_configs import ClipTextConfig -from ...types.theta import Theta -from .export import hugging_face_clip_text_model_to_theta +import functools import torch +from os import PathLike, makedirs +from typing import Union, Optional +from copy import copy +from iree.turbine.aot.params import ParameterArchiveBuilder + +from ...layers.configs.llm_configs import ClipTextConfig +from .clip import ClipTextModel +from ...types.theta import Theta, Dataset +from ...types.tensors import dtype_to_serialized_short_name +from ...utils.io import save_tensor_as_irpa +from .export import ( + clip_text_model_to_dataset, + hugging_face_clip_text_model_to_theta, + export_clip_text_model_to_iree, +) +from ...transforms.dataset import set_float_dtype + + +def clip_toy_text_model_config(dtype: Optional[torch.dtype] = None) -> ClipTextConfig: + num_attention_heads = 5 + vocab_size = 11 + return ClipTextConfig( + vocab_size=vocab_size, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + max_position_embeddings=17, + layer_norm_eps=1e-4, + num_hidden_layers=2, + bos_token_id=vocab_size - 2, + eos_token_id=vocab_size - 1, + dtype=dtype, + ) + + +def export_clip_toy_text_model_default_iree_test_data(output_dir: PathLike): + makedirs(output_dir, exist_ok=True) + + # We want to always export the same without interfering with RNG for the rest of + # the program. 
+ rng_state = torch.get_rng_state() + torch.random.manual_seed(12345) + + reference_dtype = torch.float32 + target_dtypes = [torch.float32, torch.bfloat16] + target_iree_parameters_output_paths = [] + target_mlir_output_paths = [] + batch_size = 4 + for dtype in target_dtypes: + prefix = output_dir / f"{dtype_to_serialized_short_name(dtype)}" + target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa") + target_mlir_output_paths.append(f"{prefix}.mlir") + call_prefix = output_dir / f"forward_bs{batch_size}" + input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa" + expected_last_hidden_state_output_path = ( + f"{call_prefix}_expected_result0_last_hidden_state_" + f"{dtype_to_serialized_short_name(reference_dtype)}.irpa" + ) + export_clip_toy_text_model_iree_test_data( + reference_dtype=reference_dtype, + target_dtypes=target_dtypes, + batch_size=batch_size, + input_ids_output_path=input_ids_output_path, + expected_last_hidden_state_output_path=expected_last_hidden_state_output_path, + target_iree_parameters_output_paths=target_iree_parameters_output_paths, + target_mlir_output_paths=target_mlir_output_paths, + ) + + torch.set_rng_state(rng_state) + + +def export_clip_toy_text_model_iree_test_data( + reference_dtype: torch.dtype, + target_dtypes: list[torch.dtype], + batch_size: int, + target_iree_parameters_output_paths: list[PathLike], + target_mlir_output_paths: list[PathLike], + input_ids_output_path: PathLike, + expected_last_hidden_state_output_path: PathLike, +): + reference_config = clip_toy_text_model_config(reference_dtype) + input_ids = make_random_input_token_sequences( + batch_size=batch_size, config=reference_config + ) + reference_theta = make_clip_text_model_random_theta(reference_config) + reference_model = ClipTextModel(theta=reference_theta, config=reference_config) + for i, ( + target_dtype, + target_iree_parameters_output_path, + target_mlir_output_path, + ) in enumerate( + zip( + target_dtypes, + target_iree_parameters_output_paths, + target_mlir_output_paths, + strict=True, + ) + ): + current_input_ids_output_path = None + current_expected_last_hidden_state_output_path = None + if i == 0: + current_input_ids_output_path = input_ids_output_path + current_expected_last_hidden_state_output_path = ( + expected_last_hidden_state_output_path + ) + export_clip_text_model_iree_test_data( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + target_iree_parameters_output_path=target_iree_parameters_output_path, + target_mlir_output_path=target_mlir_output_path, + input_ids_output_path=current_input_ids_output_path, + expected_last_hidden_state_output_path=current_expected_last_hidden_state_output_path, + ) + + +def export_clip_text_model_iree_test_data( + reference_model: ClipTextModel, + target_dtype: torch.dtype, + input_ids: torch.LongTensor, + target_mlir_output_path: PathLike, + target_iree_parameters_output_path: PathLike, + input_ids_output_path: Optional[PathLike] = None, + expected_last_hidden_state_output_path: Optional[PathLike] = None, +): + batch_size = input_ids.shape[0] + reference_dataset = clip_text_model_to_dataset(reference_model) + target_config = copy(reference_model.config) + target_config.dtype = target_dtype + target_dataset = Dataset( + root_theta=reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=target_dtype) + ), + properties=target_config.to_properties(), + ) + target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config) + 
export_clip_text_model_to_iree( + target_model, + batch_sizes=[batch_size], + mlir_output_path=target_mlir_output_path, + parameters_output_path=target_iree_parameters_output_path, + ) + + if input_ids_output_path is not None: + save_tensor_as_irpa(input_ids, input_ids_output_path) + + if expected_last_hidden_state_output_path is None: + return + + expected_last_hidden_state = reference_model(input_ids=input_ids)[ + "last_hidden_state" + ] + save_tensor_as_irpa( + expected_last_hidden_state, expected_last_hidden_state_output_path + ) def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta: - from transformers import CLIPTextConfig as HfCLIPTextConfig from transformers import CLIPTextModel as HfCLIPTextModel hf_config = config.to_hugging_face_clip_text_model_config() diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py index ac2480846..12c9c002b 100644 --- a/sharktank/sharktank/utils/io.py +++ b/sharktank/sharktank/utils/io.py @@ -5,10 +5,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from pathlib import Path +import torch +from os import PathLike -from iree.turbine.aot import ( - ParameterArchiveBuilder, -) +from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive class ShardedArchiveBuilder(ParameterArchiveBuilder): @@ -49,3 +49,22 @@ def path_for_rank(path: Path, rank: int): /tmp/foobar.rank0.irpa """ return path.with_suffix(f".rank{rank}{path.suffix}") + + +def save_tensor_as_irpa(tensor: torch.Tensor, path: PathLike): + """Save a single tensor into an IRPA file.""" + param_builder = ParameterArchiveBuilder() + param_builder.add_tensor("", tensor) + param_builder.save(path) + + +def load_irpa_as_tensor(tensor: torch.Tensor, path: PathLike, **kwargs): + """Load a tensor form an IRPA file that holds only one tensor.""" + params = ParameterArchive(path, **kwargs) + items = params.items() + if len(items) != 1: + raise ValueError( + f'Too many items {len(items)} in IRPA file "{path}".' + " Only a single tensor was expected." 
+ ) + return items[0][1].as_tensor() diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py index 99af4ba6f..704333c90 100644 --- a/sharktank/tests/models/clip/clip_test.py +++ b/sharktank/tests/models/clip/clip_test.py @@ -8,6 +8,7 @@ import functools import iree.compiler import os +from pathlib import Path from parameterized import parameterized from copy import copy import pytest @@ -47,18 +48,18 @@ test_prompts, ) from sharktank.models.clip.export import ( - export_clip_text_model_mlir, export_clip_text_model_dataset_from_hugging_face, hugging_face_clip_attention_to_theta, hugging_face_clip_encoder_layer_to_theta, hugging_face_clip_encoder_to_theta, hugging_face_clip_text_model_to_dataset, hugging_face_clip_text_model_to_theta, - clip_text_model_to_dataset, ) from sharktank.models.clip.testing import ( make_random_input_token_sequences, make_clip_text_model_random_theta, + export_clip_text_model_iree_test_data, + clip_toy_text_model_config, ) from sharktank.models.clip import ( ClipAttention, @@ -72,13 +73,15 @@ with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')") -@pytest.mark.usefixtures("caching", "path_prefix") +@pytest.mark.usefixtures("path_prefix") class ClipTextIreeTest(TempDirTestBase): def setUp(self): super().setUp() torch.random.manual_seed(12345) if self.path_prefix is None: - self.path_prefix = f"{self._temp_dir}/" + self.path_prefix = self._temp_dir + else: + self.path_prefix = Path(self.path_prefix) @with_clip_data def testSmokeExportLargeF32FromHuggingFace(self): @@ -90,12 +93,20 @@ def testSmokeExportLargeF32FromHuggingFace(self): huggingface_repo_id, ).download() target_dtype_name = dtype_to_serialized_short_name(torch.float32) - target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_text_model_{target_dtype_name}" + target_model_path_prefix = ( + self.path_prefix + / f"{huggingface_repo_id_as_path}_text_model_{target_dtype_name}" + ) output_path = f"{target_model_path_prefix}.irpa" export_clip_text_model_dataset_from_hugging_face( huggingface_repo_id, output_path ) + def testSmokeExportToyIreeTestData(self): + from sharktank.models.clip.export_toy_text_model_iree_test_data import main + + main([f"--output-dir={self.path_prefix/'clip_toy_text_model'}"]) + @with_clip_data def testCompareLargeIreeF32AgainstTorchEagerF32(self): self.runTestCompareIreeAgainstPretrainedTorchEager( @@ -141,43 +152,31 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( ) target_dtype_name = dtype_to_serialized_short_name(target_dtype) reference_model_path_prefix = ( - f"{self.path_prefix}{file_artifact_prefix_name}_{reference_dtype_name}" + self.path_prefix / f"{file_artifact_prefix_name}_{reference_dtype_name}" ) target_model_path_prefix = ( - f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}" - ) - - target_config = copy(reference_model.config) - target_config.dtype = target_dtype - reference_dataset = clip_text_model_to_dataset(reference_model) - target_dataset = Dataset( - root_theta=reference_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=target_config.dtype) - ), - properties=target_config.to_properties(), + self.path_prefix / f"{file_artifact_prefix_name}_{target_dtype_name}" ) parameters_path = f"{target_model_path_prefix}.irpa" - if not self.caching or not os.path.exists(parameters_path): - target_dataset.save(parameters_path) - - dataset = Dataset.load(parameters_path) - target_config = ClipTextConfig.from_properties(dataset.properties) 
input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - mlir_path = f"{target_model_path_prefix}.mlir" - if not self.caching or not os.path.exists(mlir_path): - export_clip_text_model_mlir( - parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path - ) + + export_clip_text_model_iree_test_data( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + target_mlir_output_path=mlir_path, + target_iree_parameters_output_path=parameters_path, + ) + iree_module_path = f"{target_model_path_prefix}.vmfb" - if not self.caching or not os.path.exists(iree_module_path): - iree.compiler.compile_file( - mlir_path, - output_file=iree_module_path, - extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], - ) + iree.compiler.compile_file( + mlir_path, + output_file=iree_module_path, + extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], + ) reference_result_dict = call_torch_module_function( module=reference_model, @@ -211,11 +210,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( for i in range(len(expected_outputs)) ] - actual_last_hidden_states = actual_outputs[0] - expected_last_hidden_states = expected_outputs[0] + actual_last_hidden_state = actual_outputs[0] + expected_last_hidden_state = expected_outputs[0] assert_text_encoder_state_close( - actual_last_hidden_states, expected_last_hidden_states, atol + actual_last_hidden_state, expected_last_hidden_state, atol ) def runTestCompareRandomModelIreeAgainstTorch( @@ -243,21 +242,7 @@ def runTestCompareToyModelIreeAgainstTorch( self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float ): batch_size = 4 - num_attention_heads = 5 - vocab_size = 11 - reference_config = ClipTextConfig( - vocab_size=vocab_size, - hidden_size=13 * num_attention_heads, - intermediate_size=7, - projection_dim=3, - num_attention_heads=num_attention_heads, - max_position_embeddings=17, - layer_norm_eps=1e-4, - num_hidden_layers=2, - bos_token_id=vocab_size - 2, - eos_token_id=vocab_size - 1, - dtype=reference_dtype, - ) + reference_config = clip_toy_text_model_config(reference_dtype) file_artifact_prefix_name = "clip_text_model_toy" self.runTestCompareRandomModelIreeAgainstTorch( reference_config=reference_config, @@ -404,21 +389,9 @@ def testCompareEagerToySizedModelAgainstTransformers( ): torch.set_default_dtype(reference_dtype) batch_size = 19 - tgt_len = 23 - num_attention_heads = 5 vocab_size = 11 - reference_config = transformers.CLIPTextConfig( - vocab_size=vocab_size, - hidden_size=13 * num_attention_heads, - intermediate_size=7, - projection_dim=3, - num_attention_heads=num_attention_heads, - layer_norm_eps=1e-4, - num_hidden_layers=2, - final_layer_norm=1e-3, - bos_token_id=vocab_size - 2, - eos_token_id=vocab_size - 1, - ) + config = clip_toy_text_model_config() + reference_config = config.to_hugging_face_clip_text_model_config() reference_model = HfCLIPTextModel( reference_config, ) @@ -432,7 +405,9 @@ def testCompareEagerToySizedModelAgainstTransformers( ) model = ClipTextModel(theta, config) - input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len]) + input_ids = torch.randint( + low=0, high=vocab_size, size=[batch_size, config.max_position_embeddings] + ) expected_outputs = reference_model(input_ids=input_ids) @@ -471,16 +446,10 @@ def testCompareEagerToySizedModelAgainstTransformers( ): torch.set_default_dtype(reference_dtype) batch_size = 19 - tgt_len = 23 + config = clip_toy_text_model_config() + 
reference_config = config.to_hugging_face_clip_text_model_config()
+ tgt_len = config.max_position_embeddings
 src_len = tgt_len
- num_attention_heads = 2
- reference_config = transformers.CLIPTextConfig(
- vocab_size=11,
- hidden_size=13 * num_attention_heads,
- intermediate_size=7,
- projection_dim=3,
- num_attention_heads=num_attention_heads,
- )
 reference_model = HfCLIPAttention(
 reference_config,
 )
@@ -495,7 +464,7 @@ def testCompareEagerToySizedModelAgainstTransformers(
 model = ClipAttention(theta, config)

 reference_hidden_states = make_rand_torch(
- shape=[batch_size, tgt_len, reference_config.hidden_size],
+ shape=[batch_size, tgt_len, config.hidden_size],
 dtype=reference_dtype,
 )
 reference_attention_mask = make_random_mask(
@@ -551,17 +520,10 @@ def testCompareEagerToySizedModelAgainstTransformers(
 ):
 torch.set_default_dtype(reference_dtype)
 batch_size = 19
- tgt_len = 23
+ config = clip_toy_text_model_config()
+ reference_config = config.to_hugging_face_clip_text_model_config()
+ tgt_len = config.max_position_embeddings
 src_len = tgt_len
- num_attention_heads = 2
- reference_config = transformers.CLIPTextConfig(
- vocab_size=11,
- hidden_size=13 * num_attention_heads,
- intermediate_size=7,
- projection_dim=3,
- num_attention_heads=num_attention_heads,
- layer_norm_eps=1e-4,
- )
 reference_model = HfCLIPEncoderLayer(
 reference_config,
 )
@@ -634,15 +596,8 @@ def testCompareEagerToySizedModelAgainstTransformers(
 batch_size = 19
 tgt_len = 23
 src_len = tgt_len
- num_attention_heads = 5
- reference_config = transformers.CLIPTextConfig(
- vocab_size=11,
- hidden_size=13 * num_attention_heads,
- intermediate_size=7,
- projection_dim=3,
- num_attention_heads=num_attention_heads,
- layer_norm_eps=1e-4,
- num_hidden_layers=2,
+ reference_config = (
+ clip_toy_text_model_config().to_hugging_face_clip_text_model_config()
 )
 reference_model = HfCLIPEncoder(
 reference_config,

From 56f3d217e871389c9cd11481b8a4316a82c90353 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Thu, 2 Jan 2025 09:38:43 -0800
Subject: [PATCH 38/39] Add export of Flux-dev transformer and uploading to Azure (#717)

The Flux schnell and dev variants differ slightly: dev has a guidance
layer and schnell does not.

Fixed some tensor argument element types, which were always passed as
f32 even though some of them should use the model's dtype.

Refactored the Flux transformer export boilerplate a bit.

Added a script that uploads models to Azure. Right now it uploads only
the Flux transformer models; this can become part of the CI jobs at
some point.
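As a reference for reviewers, below is a minimal usage sketch of the new
export entry points and the upload tool. The output directory, Azure
account, container, and blob name prefix are placeholders chosen for
illustration only, not values used by this repository or its CI.

    # Sketch (placeholder paths): export the full Flux-dev transformer from
    # Hugging Face and the random single-layer dev variant used for testing.
    from pathlib import Path

    import torch

    from sharktank.models.flux.export import export_flux_transformer_from_hugging_face
    from sharktank.models.flux.testing import export_dev_random_single_layer

    output_dir = Path("/tmp/flux-export")  # placeholder output location
    output_dir.mkdir(parents=True, exist_ok=True)

    # Full dev transformer; requires the FLUX.1-dev weights to be downloadable
    # from Hugging Face.
    export_flux_transformer_from_hugging_face(
        "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
        mlir_output_path=output_dir / "flux-dev-transformer.mlir",
        parameters_output_path=output_dir / "flux-dev-transformer.irpa",
    )

    # Single-layer dev-like model with random weights; exercises the guidance
    # layer without downloading the real parameters.
    export_dev_random_single_layer(
        dtype=torch.bfloat16,
        mlir_output_path=output_dir / "flux-dev-single-layer.mlir",
        parameters_output_path=output_dir / "flux-dev-single-layer.irpa",
    )

The upload tool generates its export artifacts in a temporary directory
and uploads only blobs whose MD5 hashes differ, for example:

    # Sketch (placeholder account/container/prefix): drive the upload tool
    # programmatically. --account-key is optional and falls back to the
    # AZURE_STORAGE_KEY environment variable or the default Azure credential.
    from sharktank.tools.upload_all_models_to_azure import main as upload_main

    upload_main(
        [
            "--account-name=myaccount",
            "--container-name=models",
            "--destination-name-prefix=sharktank/",
        ]
    )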
--- sharktank/pyproject.toml | 1 + sharktank/requirements-dev.txt | 4 + sharktank/sharktank/models/flux/export.py | 66 ++++- sharktank/sharktank/models/flux/flux.py | 27 +- sharktank/sharktank/models/flux/testing.py | 257 ++++++++++++++++++ .../tools/upload_all_models_to_azure.py | 57 ++++ sharktank/sharktank/types/theta.py | 20 +- sharktank/sharktank/utils/azure.py | 127 +++++++++ sharktank/sharktank/utils/hf_datasets.py | 15 + sharktank/tests/models/flux/flux_test.py | 53 +--- 10 files changed, 580 insertions(+), 47 deletions(-) create mode 100644 sharktank/requirements-dev.txt create mode 100644 sharktank/sharktank/models/flux/testing.py create mode 100644 sharktank/sharktank/tools/upload_all_models_to_azure.py create mode 100644 sharktank/sharktank/utils/azure.py diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index 01cad409b..09aca178b 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -34,6 +34,7 @@ sharktank = ["py.typed", "kernels/templates/*.mlir"] file = ["requirements.txt"] [tool.setuptools.dynamic.optional-dependencies] +dev = {file = ["requirements-dev.txt"]} testing = {file = ["requirements-tests.txt"]} [tool.pytest.ini_options] diff --git a/sharktank/requirements-dev.txt b/sharktank/requirements-dev.txt new file mode 100644 index 000000000..560aaa24c --- /dev/null +++ b/sharktank/requirements-dev.txt @@ -0,0 +1,4 @@ +# Dependencies only required during development. + +azure-identity>=1.19 +azure-storage-blob>=12.24 diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py index fae3a5362..404f00413 100644 --- a/sharktank/sharktank/models/flux/export.py +++ b/sharktank/sharktank/models/flux/export.py @@ -5,6 +5,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from os import PathLike +import os +from pathlib import Path +import torch from ...export import export_static_model_mlir from ...tools.import_hf_dataset import import_hf_dataset @@ -12,7 +15,7 @@ from ...types import Dataset from ...utils.hf_datasets import get_dataset -flux_transformer_default_batch_sizes = [4] +flux_transformer_default_batch_sizes = [1] def export_flux_transformer_model_mlir( @@ -23,6 +26,31 @@ def export_flux_transformer_model_mlir( export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes) +def export_flux_transformer_iree_parameters( + model: FluxModelV1, parameters_output_path: PathLike +): + model.theta.rename_tensors_to_paths() + # TODO: export properties + dataset = Dataset(root_theta=model.theta, properties={}) + dataset.save(parameters_output_path) + + +def export_flux_transformer( + model: FluxModelV1, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + export_flux_transformer_iree_parameters(model, parameters_output_path) + + dataset = Dataset.load(parameters_output_path) + model_with_frozen_theta = FluxModelV1(theta=dataset.root_theta, params=model.params) + model_with_frozen_theta.theta = dataset.root_theta + export_flux_transformer_model_mlir( + model_with_frozen_theta, output_path=mlir_output_path, batch_sizes=batch_sizes + ) + + def export_flux_transformer_from_hugging_face( repo_id: str, mlir_output_path: PathLike, @@ -47,3 +75,39 @@ def export_flux_transformer_from_hugging_face( export_flux_transformer_model_mlir( model, output_path=mlir_output_path, batch_sizes=batch_sizes ) + + +def export_flux_transformer_models(dir: Path): + from .testing import export_dev_random_single_layer + + 
base_dir = dir / "flux" / "transformer" + os.makedirs(base_dir) + + file_name_base = "black-forest-labs--FLUX.1-dev--black-forest-labs-transformer-bf16" + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) + + file_name_base = ( + "black-forest-labs--FLUX.1-schnell--black-forest-labs-transformer-bf16" + ) + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) + + file_name_base = "black-forest-labs--FLUX.1-dev--transformer-single-layer-b16" + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_dev_random_single_layer( + dtype=torch.bfloat16, + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index d99b14ad4..531083ae1 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -11,6 +11,7 @@ from typing import Any, Optional from collections import OrderedDict +from copy import copy import math from dataclasses import dataclass import torch @@ -96,6 +97,7 @@ def __init__(self, theta: Theta, params: FluxParams): theta, ) + self.params = copy(params) self.in_channels = params.in_channels self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: @@ -146,6 +148,8 @@ def __init__(self, theta: Theta, params: FluxParams): LastLayer(theta("final_layer")), ) + self.dtype = self._deduce_dtype() + def forward( self, img: AnyTensor, @@ -193,12 +197,12 @@ def sample_inputs( raise ValueError(f'Only function "forward" is supported. Got "{function}"') # TODO: do not hardcode these but derive the required shapes from the config. 
- img = torch.rand([batch_size, 1024, 64]) - img_ids = torch.rand([batch_size, 1024, 3]) - txt = torch.rand([batch_size, 512, 4096]) - txt_ids = torch.rand([batch_size, 512, 3]) - timesteps = torch.rand([batch_size]) - y = torch.rand([batch_size, 768]) + img = torch.rand([batch_size, 1024, 64], dtype=self.dtype) + img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32) + txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype) + txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32) + timesteps = torch.rand([batch_size], dtype=self.dtype) + y = torch.rand([batch_size, 768], dtype=self.dtype) args = tuple() kwargs = OrderedDict( @@ -211,8 +215,19 @@ def sample_inputs( ("y", y), ) ) + + if self.guidance: + kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype) + return args, kwargs + def _deduce_dtype(self) -> torch.dtype: + dtype = self.theta("img_in.weight").dtype + assert ( + dtype == self.theta("time_in.in_layer.weight").dtype + ), "Inconsistent dtype" + return dtype + ################################################################################ # Layers diff --git a/sharktank/sharktank/models/flux/testing.py b/sharktank/sharktank/models/flux/testing.py new file mode 100644 index 000000000..e0354ff7b --- /dev/null +++ b/sharktank/sharktank/models/flux/testing.py @@ -0,0 +1,257 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from os import PathLike + +from .flux import FluxParams, FluxModelV1 +from .export import export_flux_transformer, flux_transformer_default_batch_sizes +from ...types import DefaultPrimitiveTensor, Theta, save_load_theta +from ...layers.testing import ( + make_rand_torch, +) + + +def make_random_theta(config: FluxParams, dtype: torch.dtype): + # TODO: do not hardcode values. 
+ + in_channels = config.in_channels + in_channels2 = 128 + hidden_size = config.hidden_size + mlp_ratio = config.mlp_ratio + mlp_hidden_size = int((mlp_ratio - 1) * hidden_size) + mlp_hidden_size2 = int(mlp_ratio * hidden_size) + mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size) + mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size) + mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size) + context_in_dim = config.context_in_dim + time_dim = 256 + vec_dim = config.vec_in_dim + patch_size = 1 + out_channels = config.out_channels + tensor_dict = { + "img_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, in_channels), dtype=dtype) + ), + "img_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "txt_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, context_in_dim), dtype=dtype) + ), + "txt_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.in_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, time_dim), dtype=dtype) + ), + "time_in.in_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.out_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "time_in.out_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.in_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) + ), + "vector_in.in_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.out_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "vector_in.out_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.key_norm.scale": 
DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "single_blocks.0.attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear1.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size5,), dtype=dtype) + ), + "single_blocks.0.linear1.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size5, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) + ), + "single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "final_layer.linear.weight": DefaultPrimitiveTensor( # + data=make_rand_torch( + (patch_size * patch_size * out_channels, hidden_size), dtype=dtype + ) + ), + "final_layer.linear.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((patch_size * patch_size * out_channels,), dtype=dtype) + ), + "final_layer.adaLN_modulation.1.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) + ), + "final_layer.adaLN_modulation.1.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2,), dtype=dtype) + ), + } + + if config.guidance_embed: + tensor_dict["guidance_in.in_layer.weight"] = DefaultPrimitiveTensor( # + data=make_rand_torch( + ( + hidden_size, + time_dim, + ), + dtype=dtype, + ) + ) + 
tensor_dict["guidance_in.in_layer.bias"] = DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ) + tensor_dict["guidance_in.out_layer.weight"] = DefaultPrimitiveTensor( # + data=make_rand_torch( + ( + hidden_size, + hidden_size, + ), + dtype=dtype, + ) + ) + tensor_dict["guidance_in.out_layer.bias"] = DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ) + + return Theta(tensor_dict) + + +def export_dev_random_single_layer( + dtype: torch.dtype, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + rng_state = torch.get_rng_state() + torch.random.manual_seed(12345) + + dtype = torch.bfloat16 + params = FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=1, + depth_single_blocks=1, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ) + theta = make_random_theta(params, dtype) + flux = FluxModelV1( + theta=theta, + params=params, + ) + + export_flux_transformer( + flux, + mlir_output_path=mlir_output_path, + parameters_output_path=parameters_output_path, + batch_sizes=batch_sizes, + ) + + torch.set_rng_state(rng_state) diff --git a/sharktank/sharktank/tools/upload_all_models_to_azure.py b/sharktank/sharktank/tools/upload_all_models_to_azure.py new file mode 100644 index 000000000..05ce4e51a --- /dev/null +++ b/sharktank/sharktank/tools/upload_all_models_to_azure.py @@ -0,0 +1,57 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..utils.azure import upload_all_models + +import logging +import argparse + + +def main(args: list[str] = None): + parser = argparse.ArgumentParser( + description=( + "Upload all models to Azure storage. Uploads only if files are different. " + "If they need updating a snapshot will be created before uploading." + ) + ) + parser.add_argument( + "--account-name", type=str, required=True, help="Storage account name." + ) + parser.add_argument("--container-name", type=str, required=True) + parser.add_argument( + "--account-key", + type=str, + default=None, + help=( + "Access key. If not provided, will use environment variable AZURE_STORAGE_KEY" + " as key. If this is not available, will use the default Azure credential." 
+ ), + ) + parser.add_argument( + "--destination-name-prefix", + type=str, + required=True, + help="Name prefix of all blobs that will be uploaded.", + ) + parsed_args = parser.parse_args(args) + + upload_all_models( + account_name=parsed_args.account_name, + container_name=parsed_args.container_name, + destination_name_prefix=parsed_args.destination_name_prefix, + account_key=parsed_args.account_key, + ) + + +if __name__ == "__main__": + # Set the logging level for all azure-storage-* libraries + azure_logger = logging.getLogger("azure.storage") + azure_logger.setLevel(logging.INFO) + + upload_logger = logging.getLogger("sharktank.utils.azure") + upload_logger.setLevel(logging.INFO) + + main() diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 021925169..29b870782 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Any, Callable, Optional, Union, Collection, Sequence, List - +from tempfile import TemporaryFile import json from pathlib import Path from types import NotImplementedType @@ -32,7 +32,13 @@ REGISTERED_INFERENCE_TENSOR_CLASSES, ) -__all__ = ["Dataset", "flat_to_nested_dict", "Theta", "torch_module_to_theta"] +__all__ = [ + "Dataset", + "flat_to_nested_dict", + "Theta", + "torch_module_to_theta", + "save_load_theta", +] IOReportCallback = Callable[[str], None] @@ -297,6 +303,16 @@ def _norm_name_path(name_parts) -> list[str]: return accum +def save_load_theta(theta: Theta) -> Theta: + """Roundtrip to disk to avoid treating parameters as constants that would appear + in the MLIR.""" + theta.rename_tensors_to_paths() + dataset = Dataset(root_theta=theta, properties={}) + with TemporaryFile(prefix="save_load_theta", suffix=".irpa") as file_path: + dataset.save(file_path) + return Dataset.load(file_path).root_theta + + ################################################################################ # Dataset objects # diff --git a/sharktank/sharktank/utils/azure.py b/sharktank/sharktank/utils/azure.py new file mode 100644 index 000000000..b696f2b0b --- /dev/null +++ b/sharktank/sharktank/utils/azure.py @@ -0,0 +1,127 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient, ContentSettings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable, Optional +import hashlib +import os +import logging + +logger = logging.getLogger(__name__) + + +def calculate_hash(file_path: str) -> str: + hasher = hashlib.md5() + with open(file_path, "rb") as file: + buf = file.read() + hasher.update(buf) + return hasher.digest() + + +def create_blob_service_client( + account_name: str, account_key: Optional[str] = None +) -> BlobServiceClient: + if account_key is None and "AZURE_STORAGE_KEY" in os.environ: + account_key = os.environ["AZURE_STORAGE_KEY"] + if account_key: + connection_string = ( + f"DefaultEndpointsProtocol=https;AccountName={account_name};" + f"AccountKey={account_key};" + "EndpointSuffix=core.windows.net" + ) + return BlobServiceClient.from_connection_string(connection_string) + + credential = DefaultAzureCredential() + account_url = f"https://{account_name}.blob.core.windows.net" + return BlobServiceClient(account_url, credential) + + +def snapshot_and_upload_blob_if_different( + blob_service_client: BlobServiceClient, + container_name: str, + blob_name: str, + file_path: str, +): + blob_client = blob_service_client.get_blob_client(container_name, blob_name) + local_hash = calculate_hash(file_path) + + blob_exists = False + try: + blob_properties = blob_client.get_blob_properties() + existing_hash = blob_properties.content_settings.content_md5 + blob_exists = True + except Exception: + existing_hash = None + + if local_hash == existing_hash: + logger.info(f'Skipping upload to blob "{blob_name}".') + return + + if blob_exists: + blob_client.create_snapshot() + + with open(file_path, "rb") as f: + logger.info(f'Uploading to blob "{blob_name}"...') + content_settings = ContentSettings(content_md5=local_hash) + blob_client.upload_blob(f, overwrite=True, content_settings=content_settings) + logger.info(f'Blob "{blob_name}" uploaded.') + + +def upload_directory( + blob_service_client: BlobServiceClient, + container_name: str, + source_dir: str, + destination_blob_name_prefix: str, +): + for root, dirs, files in os.walk(source_dir): + for file_name in files: + file_path = Path(root) / file_name + blob_name = f"{destination_blob_name_prefix}{os.path.relpath(file_path, source_dir)}" + snapshot_and_upload_blob_if_different( + blob_service_client, container_name, blob_name, file_path + ) + + +def upload_model( + export_fn: Callable[[Path], None], + blob_service_client: BlobServiceClient, + container_name: str, + destination_blob_name_prefix: str, +): + with TemporaryDirectory() as tmp_dir: + export_fn(Path(tmp_dir)) + upload_directory( + blob_service_client, + container_name, + source_dir=tmp_dir, + destination_blob_name_prefix=destination_blob_name_prefix, + ) + + +def upload_all_models( + account_name: str, + container_name: str, + destination_name_prefix: str, + account_key: Optional[str] = None, +): + """Upload all models to Azure. + Will generate temporary export artifacts. + If MD5 hashes match with the existing blobs nothing will be uploaded. 
+ Creates snapshots if files need updating.""" + from ..models.flux.export import export_flux_transformer_models + + blob_service_client = create_blob_service_client(account_name, account_key) + + upload_model( + export_flux_transformer_models, + blob_service_client, + container_name, + destination_name_prefix, + ) + # TODO: add more models here diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index 6893b637a..bc3a08c50 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -419,6 +419,21 @@ def alias_dataset(from_name: str, to_name: str): ), ), ) +Dataset( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-dev", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-dev", + "flux1-dev.safetensors", + ), + ), +) ################################################################################ diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index fc4d23251..ee8e6d82e 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -13,9 +13,9 @@ FluxParams, ) from sharktank.models.flux.export import ( - export_flux_transformer_model_mlir, export_flux_transformer_from_hugging_face, ) +from sharktank.models.flux.testing import export_dev_random_single_layer import sharktank.ops as ops from sharktank.layers.testing import ( make_rand_torch, @@ -216,52 +216,29 @@ def setUp(self): self.num_heads = 24 self.batch_size = 5 - def testExportBfloat16SingleLayer(self): - dtype = torch.bfloat16 - params = FluxParams( - in_channels=64, - out_channels=64, - vec_in_dim=768, - context_in_dim=4096, - hidden_size=3072, - mlp_ratio=4.0, - num_heads=24, - depth=1, - depth_single_blocks=1, - axes_dim=[16, 56, 56], - theta=10_000, - qkv_bias=True, - guidance_embed=False, - ) - theta = make_random_theta(dtype) - theta = self.save_load_theta(theta) - flux = FluxModelV1( - theta=theta, - params=params, - ) - - export_flux_transformer_model_mlir( - flux, - output_path=self._temp_dir / "model.mlir", - batch_sizes=[self.batch_size], + def testExportDevRandomSingleLayerBf16(self): + export_dev_random_single_layer( + dtype=torch.bfloat16, + batch_sizes=[1], + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", ) @with_flux_data - def testExportSchnellFromHuggingFace(self): + def testExportSchnellTransformerFromHuggingFace(self): export_flux_transformer_from_hugging_face( "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", mlir_output_path=self._temp_dir / "model.mlir", parameters_output_path=self._temp_dir / "parameters.irpa", ) - def save_load_theta(self, theta: Theta): - # Roundtrip to disk to avoid treating parameters as constants that would appear - # in the MLIR. 
- theta.rename_tensors_to_paths() - dataset = Dataset(root_theta=theta, properties={}) - file_path = self._temp_dir / "parameters.irpa" - dataset.save(file_path) - return Dataset.load(file_path).root_theta + @with_flux_data + def testExportDevTransformerFromHuggingFace(self): + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", + ) if __name__ == "__main__": From e98e4586058735eb16e2c9392f1c498e393a6c3a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 2 Jan 2025 15:40:09 -0800 Subject: [PATCH 39/39] Fix signed extension in q4_1 sharktank kernel (#726) --- .../templates/mmt_block_scaled_offset_q4_unsigned.mlir | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index afe2928c0..00b98cf3f 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -98,12 +98,14 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) outs(%result_fill : !accum_tensor_type) { ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): - %bmm_mul = arith.mulf %a_element, %b_element : !a_type {% if accum_type == a_type %} + %bmm_mul = arith.mulf %a_element, %b_element : !a_type %bmm_accum = arith.addf %bmm_mul, %out : !a_type {% else %} - %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type - %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + %a_ext = arith.extf %a_element : !a_type to !accum_type + %b_ext = arith.extf %b_element : !a_type to !accum_type + %bmm_mul = arith.mulf %a_ext, %b_ext : !accum_type + %bmm_accum = arith.addf %bmm_mul, %out : !accum_type {% endif %} linalg.yield %bmm_accum : !accum_type } -> !accum_tensor_type