Skip to content

Commit

Permalink
Serialize MLIR type when instrumenting reference interpreter values (#…
Browse files Browse the repository at this point in the history
…1828)

Recently, we introduced a way to extract intermediate tensor state from
the StableHLO reference interpreter (#1784) for
instrumentation/debugging purposes. As part of this instrumentation
process, the interpreter will create an `index.csv` metadata file which
contains all serialized tensor paths and uniquely identifying `probe_id`
values in the form of:

```
probe_id,/some/absolute/path/to/numpy_0.npy
...
```

In the event that an `interpreter.probe` instruction is executed more
than once (i.e. due to being within a loop), there could be `N` entries
with the same `probe_id` value.

This current schema however does not serialize the `mlir::TensorType` of
the data. The inclusion of this type information can make it easier for
post processing tools/scripts to better interpret the metadata file,
without needing to load the accompanying `npy` file into memory to
extract size/type information. With these changes, the serialized
metadata file format now becomes:

```
probe_id,tensor<T>,/some/absolute/path/to/numpy_0.npy
...
```

Where `tensor<T>` is the type string produced by `mlir::debugString`
(i.e. `tensor<1x2xf32>`, etc). Additionally, the `expect_serialized_eq`
check dialect operation can now perform a stricter check, by also
locking down the serialized type vs the expected type.
  • Loading branch information
penagos authored Nov 30, 2023
1 parent 021e197 commit 97ad2b1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 25 deletions.
8 changes: 5 additions & 3 deletions stablehlo/reference/InterpreterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/reference/NumPy.h"
#include "stablehlo/reference/Ops.h"
Expand All @@ -44,7 +45,7 @@ namespace {

// Appends a new line item to an instrumentation metadata file, `index.json` in
// the form: `probeId,probeOutputDir/filename`.
llvm::Error writeProbeMetadata(StringRef probeId, StringRef filename,
llvm::Error writeProbeMetadata(StringRef probeId, Type type, StringRef filename,
StringRef probeOutputDir) {
if (probeOutputDir.empty())
return createStringError(llvm::errc::invalid_argument,
Expand All @@ -61,7 +62,8 @@ llvm::Error writeProbeMetadata(StringRef probeId, StringRef filename,
"Failed to open instrumentation metadata file.");

llvm::raw_fd_ostream out(fd, /*shouldClose=*/true);
out << probeId.str() << ',' << filename.str() << '\n';
out << probeId.str() << ',' << debugString(type) << ',' << filename.str()
<< '\n';

return llvm::Error::success();
}
Expand Down Expand Up @@ -212,7 +214,7 @@ llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
// After the tensor has been serialized to disk, append it to a metadata file
// to associate the serialized probe_id with the filepath. By default, this
// will live in an `index.csv` file generated in specified `probeOutputDir`.
return writeProbeMetadata(probeId, filepath, probeOutputDir);
return writeProbeMetadata(probeId, input.getType(), filepath, probeOutputDir);
}

} // namespace interpreter
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/reference/InterpreterOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def Interpreter_ProbeOp : Op<Interpreter_Dialect, "probe",
file format. Writes tensor input value to
`<output-dir>/<probe_id>_<iteration>.npy` (where output-dir is specified by
the `--probe_output_dir` flag). Additionally, adds an entry to
<output-dir>/index.csv metadata file which maps probe IDs and iterations to
<output-dir>/index.csv metadata file which maps probe IDs, types and
filenames with their tensor values.

The `probe` operation will not modify its input in any way. Probe
Expand Down
63 changes: 45 additions & 18 deletions stablehlo/tests/CheckOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ limitations under the License.
namespace mlir {
namespace stablehlo {
namespace check {
namespace {

using SerializedTensorMetadata =
std::pair</*type=*/std::string, /*path=*/std::string>;

llvm::ErrorOr<SerializedTensorMetadata> extractMetadata(StringRef line) {
// Parse a CSV record in the form of: probe_id,mlir_type,serialized_path
constexpr int kNumFields = 3;
SmallVector<StringRef, kNumFields> fields;
line.split(fields, ',', kNumFields);

if (fields.size() != 3) return llvm::errc::invalid_argument;

return std::make_pair(/*type=*/fields[1].str(), /*path=*/fields[2].str());
}
} // namespace

//===----------------------------------------------------------------------===//
// Check Dialect Constructor
Expand Down Expand Up @@ -81,12 +97,11 @@ llvm::Error evalExpectEqOp(const Tensor &lhs, const Tensor &rhs) {
return llvm::Error::success();
}

// Fetch a previously serialized filepath given a `probeId` and a `probeDir` for
// a specified `iteration` value from an `index.csv` metadata file. If no data
// is found, returns an error.
static llvm::ErrorOr<std::string> getSerializedTensorPath(StringRef probeId,
StringRef probeDir,
uint32_t iteration) {
// Fetch a previously serialized MLIR type and data filepath given a `probeId`
// and a `probeDir` for a specified `iteration` value from an `index.csv`
// metadata file. If no data is found, returns an error.
static llvm::ErrorOr<SerializedTensorMetadata> getSerializedTensorMetadata(
StringRef probeId, StringRef probeDir, uint32_t iteration) {
if (probeDir.empty()) return llvm::errc::invalid_argument;

llvm::SmallString<128> instrumentationMetadataFile(probeDir);
Expand All @@ -103,29 +118,41 @@ static llvm::ErrorOr<std::string> getSerializedTensorPath(StringRef probeId,
auto pos = line.find(probeId);

if (pos != std::string::npos && match == iteration)
return line.substr(pos + probeId.size() + 1);
return extractMetadata(line);
}

return llvm::errc::bad_address;
}

llvm::Error evalExpectSerializedEqOp(const Tensor &expected, StringRef probeId,
StringRef probeDir, uint32_t iteration) {
auto serializedFilePathOrError =
getSerializedTensorPath(probeId, probeDir, iteration);
auto serializedMetadata =
getSerializedTensorMetadata(probeId, probeDir, iteration);

if (!serializedMetadata)
return llvm::createStringError(
serializedMetadata.getError(),
"Failed to find serialized data for probe %s.", probeId.str().c_str());

const std::string type = serializedMetadata->first;
const std::string serializedPath = serializedMetadata->second;

if (!serializedFilePathOrError)
return llvm::createStringError(serializedFilePathOrError.getError(),
"Failed to find serialized data for probe");
auto tensor = numpy::deserializeTensor(serializedPath, expected.getType());

auto tensorOrError =
numpy::deserializeTensor(*serializedFilePathOrError, expected.getType());
if (!tensor)
return llvm::createStringError(tensor.getError(),
"Failed to verify serialized tensor %s.",
probeId.str().c_str());

if (!tensorOrError)
return llvm::createStringError(tensorOrError.getError(),
"Failed to verify serialized tensor.");
const std::string expectedType = debugString(expected.getType());
if (type != expectedType)
return llvm::createStringError(llvm::errc::invalid_argument,
"Serialized types don't match: %s (actual) "
"vs %s (expected) for probe %s.",
expectedType.c_str(), type.c_str(),
probeId.str().c_str());

return evalExpectEqOp(expected, *tensorOrError);
return evalExpectEqOp(expected, *tensor);
}

} // namespace check
Expand Down
6 changes: 3 additions & 3 deletions stablehlo/tests/CheckOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def CHECK_ExpectEqOp : Op<CHECK_Dialect, "expect_eq", [SameTypeOperands]> {
def CHECK_ExpectSerializedEqOp : Op<CHECK_Dialect, "expect_serialized_eq", []> {
let summary = [{Checks value of serialized tensor value.}];
let description = [{
Verifies that the value of the serialized tensor `probe_id` matches the
optionally specified input tensor at iteration `iteration`, using previously
serialized filepaths in `index.csv`.
Verifies that the value and type of the serialized tensor `probe_id` match
the optionally specified input tensor at iteration `iteration`, using
previously serialized filepaths in `index.csv`.

```mlir
check.expect_serialized_eq %arg0,
Expand Down

0 comments on commit 97ad2b1

Please sign in to comment.