Skip to content

Commit

Permalink
Allow unsafe probe_id values to be serialized properly to disk (#1840)
Browse files Browse the repository at this point in the history
This PR builds off of the recently submitted changes in #1784. Prior to
this change, the value of `probe_id` and an internal execution counter
were used to derive the filename for serializing probed tensor values
(i.e. `<probe_output_dir>/<probe_id>_<execution_count>.npy`). The
current design has the following disadvantage:

1. This can be unsafe, as no filename sanitization is performed; it relies on
the compiler (or other tooling) to produce friendly `probe_id` values.
Instead of placing this implicit assumption on the instrumented StableHLO
program, this change generates serialization filenames internally from a
strictly increasing counter.

This change proposes that the serialization filename now be derived as
follows:
```
<probe_output_dir>/probe<i>.npy
```

Where `i` is a strictly increasing positive `int64_t`, unique to each serialized probe.
  • Loading branch information
penagos authored Nov 30, 2023
1 parent f2257d2 commit 021e197
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
8 changes: 4 additions & 4 deletions stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class DefaultInterpreterFallback : public InterpreterFallback {
stablehlo::InterpreterValue(scope.findTensor(probeOp.getOperand()));
auto status = stablehlo::interpreter::evalProbeOp(
input, probeOp.getProbeId(), config.probeInstrumentationDir,
instrumentedTensors);
++serializedProbeFileId);
scope.add(probeOp.getResult(), input);
return wrapFallbackStatus(std::move(status), funcName,
"interpreter.probe");
Expand Down Expand Up @@ -96,9 +96,9 @@ class DefaultInterpreterFallback : public InterpreterFallback {
/// Interpreter configuration.
const InterpreterConfiguration &config;

/// If the input StableHLO program has been instrumented, keep track of how
/// many times a given operation has been executed.
llvm::StringMap<int32_t> instrumentedTensors;
/// Probe instrumentation counter for uniquely identifying instrumented tensor
/// filenames.
int64_t serializedProbeFileId = 0;
};

} // namespace
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/reference/InterpreterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,17 @@ SmallVector<InterpreterValue> evalRunParallelOp(
return results;
}

// `serializedProbeFileId` should be a unique positive integer which can be used
// to unambiguously derive a serialized filename for a given `probeId`.
llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
StringRef probeOutputDir,
llvm::StringMap<int32_t> &probeIterations) {
int64_t serializedProbeFileId) {
llvm::SmallString<128> filepath(probeOutputDir);

// To properly support loops, append a suffix denoting how many times this
// specific probe_id has executed.
const int32_t numTimesExecuted = ++probeIterations[probeId];

// Use an increasing unique integer to write to disk to avoid any odd file
// names as a result of unsafe probe_id values.
llvm::sys::path::append(
filepath, probeId + "_" + std::to_string(numTimesExecuted) + ".npy");
filepath, "probe" + std::to_string(serializedProbeFileId) + ".npy");
auto tensor = input.getTensor();
if (auto serializationResultError =
numpy::serializeTensor(filepath, tensor.getType(), tensor.getData()))
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/reference/InterpreterOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ SmallVector<InterpreterValue> evalRunParallelOp(

llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
StringRef probeOutputDir,
llvm::StringMap<int32_t> &probeIterations);
int64_t serializedProbeFileId);

} // namespace interpreter
} // namespace stablehlo
Expand Down
11 changes: 11 additions & 0 deletions stablehlo/tests/interpret_probe.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ func.func @probe_c32() {

// -----

// Exercises a probe whose probe_id ("probe/0") contains a '/' character, i.e.
// an id that is not safe to use verbatim as part of a filename.
func.func @probe_sanitized_probe_id() {
%0 = stablehlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64>
%1 = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
// Value under test: the sum of the two constants above.
%2 = stablehlo.add %0, %1 : tensor<3x1xi64>
// Probe the sum using the id that includes a path separator.
%3 = interpreter.probe %2, probe_id = "probe/0" : tensor<3x1xi64>
// NOTE(review): presumably compares %3 against the tensor serialized for the
// same probe_id -- confirm against the check dialect's definition.
check.expect_serialized_eq %3, probe_id = "probe/0" : tensor<3x1xi64>
func.return
}

// -----

func.func @probe_iterations() {
// int i = 0;
// int sum = 0;
Expand Down

0 comments on commit 021e197

Please sign in to comment.