diff --git a/include/bh_python/register_axis.hpp b/include/bh_python/register_axis.hpp index 6f6c264e..da8276e5 100644 --- a/include/bh_python/register_axis.hpp +++ b/include/bh_python/register_axis.hpp @@ -47,6 +47,27 @@ auto vectorize_index(T input) { #define BHP_NOEXCEPT_17 #endif +namespace detail { +template +decltype(auto) axis_cast(py::handle x) { + return special_cast(x); +} + +template <> +inline decltype(auto) axis_cast(py::handle x) { + if(py::isinstance(x)) + return py::cast(x); + + auto val = py::cast(x); + auto ival = static_cast(val); + + if(static_cast(ival) == val) + return ival; + + throw py::type_error(py::str("cannot cast {} to int").format(val)); +} +} // namespace detail + // we overload vectorize index for category axis template auto vectorize_index(int (bh::axis::category::*pindex)(const T&) @@ -56,7 +77,7 @@ auto vectorize_index(int (bh::axis::category::*pindex)(c auto index = std::mem_fn(pindex); if(detail::is_value(arg)) { - auto index_value = index(self, detail::special_cast(arg)); + auto index_value = index(self, detail::axis_cast(arg)); if(index_value >= self.size()) throw pybind11::key_error(py::str("{!r} not in axis").format(arg)); return py::cast(index_value);