Skip to content

Commit

Permalink
fix segfault issue #283 (#284)
Browse files Browse the repository at this point in the history
This closes #283

The `XlaSend` call requires `envpool` to make a copy of the `action` to
prevent `action` from being recycled by the XLA runtime before `envpool`
finishes using it. Originally, I used `cudaMemcpy` to make sure the copy
was finished synchronously. However, it seems to cause a problem with
issue #283.

Here, I replace the original `cudaMemcpy` call with the async version, and
an explicit `streamSynchronize`.

It is not clear how `cudaMemcpy` in the default stream in a custom call
interacts with the stream managed by pjrt. However, from the code
[here](https://github.com/tensorflow/tensorflow/blob/0d2d79e84c9bdf71c737ad17a7b1dc04d9efc24f/tensorflow/compiler/xla/g3doc/custom_call.md),
I can hypothesize that an explicit stream synchronization in the custom
call is safe.
  • Loading branch information
mavenlin authored Oct 26, 2023
1 parent a9d2ec9 commit a1249e0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions envpool/core/xla.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Array GpuBufferToArray(cudaStream_t stream, const void* buffer,
spec = spec.Batch(batch_size);
}
Array ret(spec);
cudaMemcpy(ret.Data(), buffer, ret.size * ret.element_size,
cudaMemcpyDeviceToHost);
cudaMemcpyAsync(ret.Data(), buffer, ret.size * ret.element_size,
cudaMemcpyDeviceToHost, stream);
return ret;
}

Expand Down Expand Up @@ -161,6 +161,7 @@ struct XlaSend {
...);
},
action_spec);
cudaStreamSynchronize(stream);
envpool->Send(action);
}
};
Expand Down

0 comments on commit a1249e0

Please sign in to comment.