
Assertion Error on CustomCall with input/output aliasing #21312

Open
wsmoses opened this issue Jan 11, 2025 · 2 comments
wsmoses commented Jan 11, 2025

module {
  func.func private @main_hlo_call_205bd746421f59dc(%arg0: tensor<17xf64>, %arg1: tensor<16xf64>, %arg2: tensor<1x34x34xf64>) -> tensor<f64> {
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<20> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %0:2 = stablehlo.custom_call @enzymexla_compile_gpu(%arg2, %arg0) {api_version = 4 : i32, backend_config = {attr = "P0\C1P\B3q\00\00\000\C1P\B3q\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>]} : (tensor<1x34x34xf64>, tensor<17xf64>) -> (tensor<1x34x34xf64>, tensor<17xf64>)
    %1 = stablehlo.slice %0#1 [0:1] : (tensor<17xf64>) -> tensor<1xf64>
    %2 = stablehlo.reshape %1 : (tensor<1xf64>) -> tensor<f64>
    return %2 : tensor<f64>
  }
  func.func @main(%arg0: tensor<17xf64>, %arg1: tensor<16xf64>, %arg2: tensor<34x34x1xf64>) -> tensor<f64> {
    %0 = stablehlo.transpose %arg2, dims = [2, 1, 0] : (tensor<34x34x1xf64>) -> tensor<1x34x34xf64>
    %1 = call @main_hlo_call_205bd746421f59dc(%arg0, %arg1, %0) : (tensor<17xf64>, tensor<16xf64>, tensor<1x34x34xf64>) -> tensor<f64>
    return %1 : tensor<f64>
  }
}

Compiling the preceding MLIR module triggers an assertion error with the following stack trace.

        xla::status_macros::MakeErrorStream::Impl::GetStatus()
        xla::ShapeVerifier::HandleCustomCall(xla::HloInstruction*)
        absl::lts_20230802::Status xla::HloInstruction::Visit<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*)

        absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
        absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
        xla::HloVerifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::HloPassInterface::Run(xla::HloModule*)
        xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*)
        xla::gpu::NVPTXCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*)
        xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&)
        xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&)
        xla::Service::BuildExecutable(xla::HloModuleProto const&, std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, xla::Backend*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, bool)
        xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::PjRtStreamExecutorClient::CompileInternal(xla::XlaComputation const&, std::vector<xla::Shape const*, std::allocator<xla::Shape const*> > const&, std::function<absl::lts_20230802::StatusOr<std::pair<std::vector<xla::Shape, std::allocator<xla::Shape> >, xla::Shape> > (xla::HloModule const&)>, xla::CompileOptions)
        xla::PjRtStreamExecutorClient::Compile(mlir::ModuleOp, xla::CompileOptions)

wsmoses commented Jan 11, 2025

The problem here is presumably that the custom call's first result is never used, so the layout chosen for that output ends up differing from the layout of its aliased input.
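
If that diagnosis is right, one untested workaround sketch is to keep the otherwise-dead aliased result live, e.g. by also returning it, so layout assignment has a use tying that output's layout to its aliased operand. The function name below is a placeholder; the custom_call attributes are copied unchanged from the original module.

func.func private @main_aliased_live(%arg0: tensor<17xf64>, %arg2: tensor<1x34x34xf64>) -> (tensor<f64>, tensor<1x34x34xf64>) {
  %0:2 = stablehlo.custom_call @enzymexla_compile_gpu(%arg2, %arg0) {api_version = 4 : i32, backend_config = {attr = "P0\C1P\B3q\00\00\000\C1P\B3q\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>]} : (tensor<1x34x34xf64>, tensor<17xf64>) -> (tensor<1x34x34xf64>, tensor<17xf64>)
  %1 = stablehlo.slice %0#1 [0:1] : (tensor<17xf64>) -> tensor<1xf64>
  %2 = stablehlo.reshape %1 : (tensor<1xf64>) -> tensor<f64>
  // Returning %0#0 keeps the aliased tensor<1x34x34xf64> output live.
  return %2, %0#0 : tensor<f64>, tensor<1x34x34xf64>
}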

@NaiyerRizz

Hi @wsmoses,
Can you please provide HLO (not MLIR) code, without your custom function, that still fails to compile?
Thanks.
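
For reference, an untested sketch of roughly what such a reduced HLO reproducer might look like. The custom-call target, shapes, and aliasing are carried over from the MLIR module above; the backend_config blob is omitted, and the instruction names are placeholders.

HloModule aliasing_repro

ENTRY main {
  %p0 = f64[17]{0} parameter(0)
  %p1 = f64[1,34,34]{2,1,0} parameter(1)
  %cc = (f64[1,34,34]{2,1,0}, f64[17]{0}) custom-call(%p1, %p0), custom_call_target="enzymexla_compile_gpu", api_version=API_VERSION_TYPED_FFI, output_to_operand_aliasing={{0}: (0, {}), {1}: (1, {})}
  %gte = f64[17]{0} get-tuple-element(%cc), index=1
  %sl = f64[1]{0} slice(%gte), slice={[0:1]}
  ROOT %res = f64[] reshape(%sl)
}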

NaiyerRizz self-assigned this Jan 16, 2025