Skip to content

Commit

Permalink
Fix support for unary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
zwimer committed Oct 8, 2022
1 parent 46be8b1 commit 58f6d36
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 23 deletions.
67 changes: 44 additions & 23 deletions source/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,46 @@ using namespace fmt::literals;

namespace binder {

static std::map<string, string > const cpp_python_operator_map{
{"operator+", "__add__"}, //
{"operator-", "__sub__"}, //
{"operator*", "__mul__"}, //
{"operator/", "__div__"}, //

{"operator+=", "__iadd__"}, //
{"operator-=", "__isub__"}, //
{"operator*=", "__imul__"}, //
{"operator/=", "__idiv__"}, //

{"operator()", "__call__"}, //
{"operator==", "__eq__"}, //
{"operator!=", "__ne__"}, //
{"operator[]", "__getitem__"}, //
{"operator=", "assign"}, //
{"operator++", "plus_plus"}, //
{"operator--", "minus_minus"}, //

{"operator->", "arrow"}, //
};
// Map C++ operators to python operators
// For C++ operators that share a name, index the vector based on the number
// of arguments excluding 'this'. For example, unary+ is map["operator+"][0]
// For C++ operators that do not share a name, index = 0

// Return the python operator that maps to the C++ operator; returns "" if no mapping exists
string cpp_python_operator(const FunctionDecl & F) {
static std::map<string, vector<string>> const m {
{"operator+", {"__pos__", "__add__"}}, //
{"operator-", {"__neg__", "__sub__"}}, //
{"operator*", {"dereference", "__mul__"}}, //
{"operator/", {"__div__"}}, //

{"operator+=", {"__iadd__"}}, //
{"operator-=", {"__isub__"}}, //
{"operator*=", {"__imul__"}}, //
{"operator/=", {"__idiv__"}}, //

{"operator()", {"__call__"}}, //
{"operator==", {"__eq__"}}, //
{"operator!=", {"__ne__"}}, //
{"operator[]", {"__getitem__"}}, //
{"operator=", {"assign"}}, //
{"operator++", {"pre_increment", "pre_increment"}}, //
{"operator--", {"pre_decrement", "post_decrement"}}, //

{"operator->", {"arrow"}} //
};
const auto & found = m.find(F.getNameAsString());
if (found != m.end()) {
const auto & vec { found->second };
if (vec.size() == 1) { return vec[0]; }
const auto n = F.getNumParams();
if (vec.size() > n) {
return vec[n];
}
}
return {};
}


// Generate function argument list separate by comma: int, bool, std::string
string function_arguments(clang::FunctionDecl const *record)
Expand Down Expand Up @@ -198,7 +217,9 @@ string template_specialization(FunctionDecl const *F)
// generate string represetiong class name that could be used in python
string python_function_name(FunctionDecl const *F)
{
if( F->isOverloadedOperator() ) return cpp_python_operator_map.at(F->getNameAsString());
if( F->isOverloadedOperator() ) {
return cpp_python_operator(*F);
}
else {
// if( auto m = dyn_cast<CXXMethodDecl>(F) ) {
// }
Expand Down Expand Up @@ -468,7 +489,7 @@ bool is_bindable_raw(FunctionDecl const *F)

if( F->isOverloadedOperator() ) {
// outs() << "Operator: " << F->getNameAsString() << '\n';
if( !isa<CXXMethodDecl>(F) or !cpp_python_operator_map.count(F->getNameAsString()) ) return false;
if( !isa<CXXMethodDecl>(F) or (cpp_python_operator(*F).size() == 0) ) return false;
}

r &= F->getTemplatedKind() != FunctionDecl::TK_FunctionTemplate /*and !F->isOverloadedOperator()*/ and !isa<CXXConversionDecl>(F) and !F->isDeleted();
Expand Down
4 changes: 4 additions & 0 deletions test/T12.operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

struct T
{
T &operator+() { return *this; }
T &operator-() { return *this; }
T &operator*() { return *this; }

T &operator+(int) { return *this; }
T &operator-(int) { return *this; }
T &operator*(int) { return *this; }
Expand Down
3 changes: 3 additions & 0 deletions test/T12.operator.ref
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ void bind_T12_operator(std::function< pybind11::module &(std::string const &name
{ // T file:T12.operator.hpp line:15
pybind11::class_<T, std::shared_ptr<T>> cl(M(""), "T", "");
cl.def( pybind11::init( [](){ return new T(); } ) );
cl.def("__pos__", (struct T & (T::*)()) &T::operator+, "C++: T::operator+() --> struct T &", pybind11::return_value_policy::automatic);
cl.def("__neg__", (struct T & (T::*)()) &T::operator-, "C++: T::operator-() --> struct T &", pybind11::return_value_policy::automatic);
cl.def("dereference", (struct T & (T::*)()) &T::operator*, "C++: T::operator*() --> struct T &", pybind11::return_value_policy::automatic);
cl.def("__add__", (struct T & (T::*)(int)) &T::operator+, "C++: T::operator+(int) --> struct T &", pybind11::return_value_policy::automatic, pybind11::arg(""));
cl.def("__sub__", (struct T & (T::*)(int)) &T::operator-, "C++: T::operator-(int) --> struct T &", pybind11::return_value_policy::automatic, pybind11::arg(""));
cl.def("__mul__", (struct T & (T::*)(int)) &T::operator*, "C++: T::operator*(int) --> struct T &", pybind11::return_value_policy::automatic, pybind11::arg(""));
Expand Down

0 comments on commit 58f6d36

Please sign in to comment.