Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
mavenlin committed Oct 26, 2023
1 parent 47ad258 commit eac02cd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 20 deletions.
20 changes: 2 additions & 18 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
push:
branches:
- main
- debug
tags:
- v*

Expand All @@ -14,7 +15,7 @@ jobs:
container: trinkle23897/envpool-release:2023-01-02-5f1a5fd
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -34,20 +35,3 @@ jobs:
with:
name: wheel
path: wheelhouse/

publish:
runs-on: ubuntu-latest
needs: [release]
steps:
- uses: actions/download-artifact@v3
with:
path: artifact
- name: Move files so the next action can find them
run: |
mkdir dist && mv artifact/wheel/* dist/
ls dist/
- name: Publish distribution to PyPI
if: startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
5 changes: 3 additions & 2 deletions envpool/core/xla.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Array GpuBufferToArray(cudaStream_t stream, const void* buffer,
spec = spec.Batch(batch_size);
}
Array ret(spec);
cudaMemcpy(ret.Data(), buffer, ret.size * ret.element_size,
cudaMemcpyDeviceToHost);
cudaMemcpyAsync(ret.Data(), buffer, ret.size * ret.element_size,
cudaMemcpyDeviceToHost, stream);
return ret;
}

Expand Down Expand Up @@ -161,6 +161,7 @@ struct XlaSend {
...);
},
action_spec);
cudaStreamSynchronize(stream);
envpool->Send(action);
}
};
Expand Down
7 changes: 7 additions & 0 deletions envpool/core/xla_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <string>
#include <tuple>
#include <vector>
#include <cassert>

namespace py = pybind11;

Expand All @@ -36,13 +37,19 @@ static auto SpecToTuple(const Spec& spec) {

template <std::size_t N>
void ToArray(const void** raw, std::array<void*, N>* array) {
for (int j = 0; j < N; ++j) {
assert(raw[j] != nullptr);
}
int i = 0;
std::apply([&](auto&&... a) { ((a = const_cast<void*>(raw[i++])), ...); },
*array);
}

template <std::size_t N>
void ToArray(void** raw, std::array<void*, N>* array) {
for (int j = 0; j < N; ++j) {
assert(raw[j] != nullptr);
}
int i = 0;
std::apply([&](auto&&... a) { ((a = raw[i++]), ...); }, *array);
}
Expand Down

0 comments on commit eac02cd

Please sign in to comment.