diff --git a/envpool/core/xla.h b/envpool/core/xla.h index e70fbb98..f4d9afff 100644 --- a/envpool/core/xla.h +++ b/envpool/core/xla.h @@ -74,8 +74,8 @@ Array GpuBufferToArray(cudaStream_t stream, const void* buffer, spec = spec.Batch(batch_size); } Array ret(spec); - cudaMemcpy(ret.Data(), buffer, ret.size * ret.element_size, - cudaMemcpyDeviceToHost); + cudaMemcpyAsync(ret.Data(), buffer, ret.size * ret.element_size, + cudaMemcpyDeviceToHost, stream); return ret; } @@ -161,6 +161,7 @@ struct XlaSend { ...); }, action_spec); + cudaStreamSynchronize(stream); envpool->Send(action); } };