diff --git a/src/vma/proto/netlink_socket_mgr.h b/src/vma/proto/netlink_socket_mgr.h index 1e1c978c2..4b324fe86 100644 --- a/src/vma/proto/netlink_socket_mgr.h +++ b/src/vma/proto/netlink_socket_mgr.h @@ -34,6 +34,7 @@ #ifndef NETLINK_SOCKET_MGR_H #define NETLINK_SOCKET_MGR_H +#include #include #include #include @@ -48,6 +49,11 @@ #include #include #include +#include +#include +#include +#include +#include #include "utils/bullseye.h" #include "utils/lock_wrapper.h" @@ -96,19 +102,19 @@ class netlink_socket_mgr table_t m_tab; - virtual bool parse_enrty(nlmsghdr *nl_header, Type *p_val) = 0; + virtual bool parse_entry(struct nl_object * nl_obj, void *p_val_context) = 0; virtual void update_tbl(); virtual void print_val_tbl(); void build_request(struct nlmsghdr **nl_msg); bool query(struct nlmsghdr *&nl_msg, int &len); int recv_info(); - void parse_tbl(int len, int *p_ent_num = NULL); + void parse_tbl_from_latest_cache(struct nl_cache *cache_state); private: nl_data_t m_data_type; - int m_fd; // netlink socket to communicate with the kernel + nl_sock * m_sock; // netlink socket to communicate with the kernel uint32_t m_pid; // process pid uint32_t m_seq_num; // seq num of the netlink messages char m_msg_buf[MSG_BUFF_SIZE]; // we use this buffer for sending/receiving netlink messages @@ -130,15 +136,19 @@ netlink_socket_mgr ::netlink_socket_mgr(nl_data_t data_type) memset(m_msg_buf, 0, m_buff_size); // Create Socket + BULLSEYE_EXCLUDE_BLOCK_START - if ((m_fd = orig_os_api.socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE)) < 0) { + m_sock = nl_socket_alloc(); + if (m_sock == nullptr) { __log_err("NL socket Creation: "); return; } - if (orig_os_api.fcntl(m_fd, F_SETFD, FD_CLOEXEC) != 0) { - __log_warn("Fail in fctl, error = %d", errno); - } + if (nl_connect(m_sock, NETLINK_ROUTE) < 0) { + __log_err("NL socket Connection: "); + nl_socket_free(m_sock); + } + BULLSEYE_EXCLUDE_BLOCK_END __log_dbg("Done"); @@ -148,138 +158,41 @@ template netlink_socket_mgr ::~netlink_socket_mgr() { __log_dbg(""); - if (m_fd) { - orig_os_api.close(m_fd); - m_fd = -1; + if (m_sock != nullptr) { + nl_socket_free(m_sock); + m_sock = nullptr; } __log_dbg("Done"); } -// This function build Netlink request to retrieve data (Rule, Route) from kernel. -// Parameters : -// nl_msg : request to be returned +// Update data in a table template -void netlink_socket_mgr ::build_request(struct nlmsghdr **nl_msg) +void netlink_socket_mgr ::update_tbl() { - struct rtmsg *rt_msg; - - memset(m_msg_buf, 0, m_buff_size); - - // point the header and the msg structure pointers into the buffer - *nl_msg = (struct nlmsghdr *)m_msg_buf; - rt_msg = (struct rtmsg *)NLMSG_DATA(*nl_msg); + m_tab.entries_num = 0; - //Fill in the nlmsg header - (*nl_msg)->nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg)); - (*nl_msg)->nlmsg_seq = m_seq_num++; - (*nl_msg)->nlmsg_pid = m_pid; - rt_msg->rtm_family = AF_INET; + struct nl_cache *cache_state = {0}; + int err = 0; + // cache allocation fetches the latest existing rules/routes if (m_data_type == RULE_DATA_TYPE) { - (*nl_msg)->nlmsg_type = RTM_GETRULE; + err = rtnl_rule_alloc_cache(m_sock, AF_INET, &cache_state); } else if (m_data_type == ROUTE_DATA_TYPE) { - (*nl_msg)->nlmsg_type = RTM_GETROUTE; + err = rtnl_route_alloc_cache(m_sock, AF_INET, 0, &cache_state); } - (*nl_msg)->nlmsg_flags = NLM_F_DUMP | NLM_F_REQUEST; - -} - -// Query built request and receive requested data (Rule, Route) -// Parameters: -// nl_msg : request that is built previously. -// len : length of received data. -template -bool netlink_socket_mgr ::query(struct nlmsghdr *&nl_msg, int &len) -{ - if(m_fd < 0) - return false; - - BULLSEYE_EXCLUDE_BLOCK_START - if(orig_os_api.send(m_fd, nl_msg, nl_msg->nlmsg_len, 0) < 0){ - __log_err("Write To Socket Failed...\n"); - return false; - } - if((len = recv_info()) < 0) { - __log_err("Read From Socket Failed...\n"); - return false; + if (err < 0) + { + throw_vma_exception("Failed to allocate route cache"); } - BULLSEYE_EXCLUDE_BLOCK_END - - return true; -} - -// Receive requested data and save it locally. -// Return length of received data. -template -int netlink_socket_mgr ::recv_info() -{ - struct nlmsghdr *nlHdr; - int readLen = 0, msgLen = 0; - - char *buf_ptr = m_msg_buf; - - do{ - //Receive response from the kernel - BULLSEYE_EXCLUDE_BLOCK_START - if((readLen = orig_os_api.recv(m_fd, buf_ptr, MSG_BUFF_SIZE - msgLen, 0)) < 0){ - __log_err("SOCK READ: "); - return -1; - } - - nlHdr = (struct nlmsghdr *)buf_ptr; - - //Check if the header is valid - if((NLMSG_OK(nlHdr, (u_int)readLen) == 0) || (nlHdr->nlmsg_type == NLMSG_ERROR)) - { - __log_err("Error in received packet, readLen = %d, msgLen = %d, type=%d, bufLen = %d", readLen, nlHdr->nlmsg_len, nlHdr->nlmsg_type, MSG_BUFF_SIZE); - if (nlHdr->nlmsg_len == MSG_BUFF_SIZE) { - __log_err("The buffer we pass to netlink is too small for reading the whole table"); - } - return -1; - } - BULLSEYE_EXCLUDE_BLOCK_END - - buf_ptr += readLen; - msgLen += readLen; - - //Check if the its the last message - if(nlHdr->nlmsg_type == NLMSG_DONE || - (nlHdr->nlmsg_flags & NLM_F_MULTI) == 0) { - break; - } - - } while((nlHdr->nlmsg_seq != m_seq_num) || (nlHdr->nlmsg_pid != m_pid)); - return msgLen; -} - -// Update data in a table -template -void netlink_socket_mgr ::update_tbl() -{ - struct nlmsghdr *nl_msg = NULL; - int counter = 0; - int len = 0; - - m_tab.entries_num = 0; - - // Build Netlink request to get route entry - build_request(&nl_msg); - - // Query built request and receive requested data - if (!query(nl_msg, len)) - return; // Parse received data in custom object (route_val) - parse_tbl(len, &counter); - - m_tab.entries_num = counter; - - if (counter >= MAX_TABLE_SIZE) { + parse_tbl_from_latest_cache(cache_state); + if (m_tab.entries_num >= MAX_TABLE_SIZE) { __log_warn("reached the maximum route table size"); } } @@ -289,20 +202,28 @@ void netlink_socket_mgr ::update_tbl() // len : length of received data. // p_ent_num : number of rows in received data. template -void netlink_socket_mgr ::parse_tbl(int len, int *p_ent_num) +void netlink_socket_mgr::parse_tbl_from_latest_cache(struct nl_cache *cache_state) { - struct nlmsghdr *nl_header; - int entry_cnt = 0; + uint16_t entry_cnt = 0; - nl_header = (struct nlmsghdr *) m_msg_buf; - for(;NLMSG_OK(nl_header, (u_int)len) && entry_cnt < MAX_TABLE_SIZE; nl_header = NLMSG_NEXT(nl_header, len)) + struct nl_iterator_context { - if (parse_enrty(nl_header, &m_tab.value[entry_cnt])) { - entry_cnt++; + Type * p_val_array; + uint16_t& entry_cnt; + netlink_socket_mgr * this_ptr; + } iterator_context = {m_tab.value, entry_cnt, this}; + + // a lambda can't be casted to a c-fptr with ref captures - so we provide context ourselves + nl_cache_foreach(cache_state, [](struct nl_object * nl_obj, void *context) { + nl_iterator_context* operation_context = reinterpret_cast(context); + const bool is_valid_entry = operation_context->this_ptr->parse_entry(nl_obj, operation_context->p_val_array + operation_context->entry_cnt); + if (is_valid_entry) + { + ++operation_context->entry_cnt; } - } - if (p_ent_num) - *p_ent_num = entry_cnt; + }, &iterator_context); + + m_tab.entries_num = entry_cnt; } //print the table diff --git a/src/vma/proto/route_table_mgr.cpp b/src/vma/proto/route_table_mgr.cpp index 597d6343a..f8f045cdf 100644 --- a/src/vma/proto/route_table_mgr.cpp +++ b/src/vma/proto/route_table_mgr.cpp @@ -234,88 +234,84 @@ void route_table_mgr::rt_mgr_update_source_ip() } } -bool route_table_mgr::parse_enrty(nlmsghdr *nl_header, route_val *p_val) +bool route_table_mgr::parse_entry(struct nl_object * nl_obj, void *p_val_context) { - int len; - struct rtmsg *rt_msg; - struct rtattr *rt_attribute; - - // get route entry header - rt_msg = (struct rtmsg *) NLMSG_DATA(nl_header); + route_val* p_val = static_cast(p_val_context); + // Cast the generic nl_object to a specific route or rule object + struct rtnl_route *route = reinterpret_cast(nl_obj); // we are not concerned about the local and default route table - if (rt_msg->rtm_family != AF_INET || rt_msg->rtm_table == RT_TABLE_LOCAL) - return false; - - p_val->set_protocol(rt_msg->rtm_protocol); - p_val->set_scope(rt_msg->rtm_scope); - p_val->set_type(rt_msg->rtm_type); - p_val->set_table_id(rt_msg->rtm_table); - - in_addr_t dst_mask = htonl(VMA_NETMASK(rt_msg->rtm_dst_len)); - p_val->set_dst_mask(dst_mask); - p_val->set_dst_pref_len(rt_msg->rtm_dst_len); + if (rtnl_route_get_family(route) != AF_INET || rtnl_route_get_table(route) == RT_TABLE_LOCAL) + return false; + + // Set protocol, scope, type, and table ID using libnl functions + p_val->set_protocol(rtnl_route_get_protocol(route)); + p_val->set_scope(rtnl_route_get_scope(route)); + p_val->set_type(rtnl_route_get_type(route)); + p_val->set_table_id(rtnl_route_get_table(route)); + + // Set destination mask and prefix length + struct nl_addr *dst = rtnl_route_get_dst(route); + if (dst != nullptr) { + in_addr_t dst_mask = htonl(VMA_NETMASK(nl_addr_get_prefixlen(dst))); + p_val->set_dst_mask(dst_mask); + p_val->set_dst_pref_len(nl_addr_get_prefixlen(dst)); + } - len = RTM_PAYLOAD(nl_header); - rt_attribute = (struct rtattr *) RTM_RTA(rt_msg); + parse_attr(route, p_val); - for (;RTA_OK(rt_attribute, len);rt_attribute=RTA_NEXT(rt_attribute,len)) { - parse_attr(rt_attribute, p_val); - } p_val->set_state(true); p_val->set_str(); return true; } -void route_table_mgr::parse_attr(struct rtattr *rt_attribute, route_val *p_val) +void route_table_mgr::parse_attr(struct rtnl_route *route, route_val *p_val) { - switch (rt_attribute->rta_type) { - case RTA_DST: - p_val->set_dst_addr(*(in_addr_t *)RTA_DATA(rt_attribute)); - break; - // next hop IPv4 address - case RTA_GATEWAY: - p_val->set_gw(*(in_addr_t *)RTA_DATA(rt_attribute)); - break; - // unique ID associated with the network interface - case RTA_OIF: - p_val->set_if_index(*(int *)RTA_DATA(rt_attribute)); - char if_name[IFNAMSIZ]; - if_indextoname(p_val->get_if_index(),if_name); - p_val->set_if_name(if_name); - break; - case RTA_SRC: - case RTA_PREFSRC: - p_val->set_src_addr(*(in_addr_t *)RTA_DATA(rt_attribute)); - break; - case RTA_TABLE: - p_val->set_table_id(*(uint32_t *)RTA_DATA(rt_attribute)); - break; - case RTA_METRICS: - { - struct rtattr *rta = (struct rtattr *)RTA_DATA(rt_attribute); - int len = RTA_PAYLOAD(rt_attribute); - uint16_t type; - while (RTA_OK(rta, len)) { - type = rta->rta_type; - switch (type) { - case RTAX_MTU: - p_val->set_mtu(*(uint32_t *)RTA_DATA(rta)); - break; - default: - rt_mgr_logdbg("got unexpected METRICS %d %x", - type, *(uint32_t *)RTA_DATA(rta)); - break; - } - rta = RTA_NEXT(rta, len); - } - break; - } - default: - rt_mgr_logdbg("got unexpected type %d %x", rt_attribute->rta_type, - *(uint32_t *)RTA_DATA(rt_attribute)); - break; - } + struct nl_addr *addr; + + // Destination Address + addr = rtnl_route_get_dst(route); + if (addr) { + p_val->set_dst_addr(*(in_addr_t *)nl_addr_get_binary_addr(addr)); + } + + // Gateway Address (Next Hop) + struct rtnl_nexthop *nh = rtnl_route_nexthop_n(route, 0); // Assuming the first nexthop + if (nh) { + addr = rtnl_route_nh_get_gateway(nh); + if (addr) { + p_val->set_gw(*(in_addr_t *)nl_addr_get_binary_addr(addr)); + } + } + + // Output Interface Index and Name + const int if_index = rtnl_route_nh_get_ifindex(nh); + if (if_index > 0) { + p_val->set_if_index(if_index); + + char if_name[IFNAMSIZ] = {0}; + if_indextoname(if_index, if_name); + p_val->set_if_name(if_name); + } + + // Source Address + addr = rtnl_route_get_pref_src(route); + if (addr) { + p_val->set_src_addr(*(in_addr_t *)nl_addr_get_binary_addr(addr)); + } + + // Table ID + int table_id = rtnl_route_get_table(route); + p_val->set_table_id(table_id); + + // Metrics (e.g., MTU) + uint32_t mtu = 0; + int get_metric_result = rtnl_route_get_metric(route, RTAX_MTU, &mtu); + if (get_metric_result == 0) { + if (mtu > 0) { + p_val->set_mtu(mtu); + } + } } bool route_table_mgr::find_route_val(in_addr_t &dst, uint32_t table_id, route_val* &p_val) diff --git a/src/vma/proto/route_table_mgr.h b/src/vma/proto/route_table_mgr.h index fb0310f13..7287a19c3 100644 --- a/src/vma/proto/route_table_mgr.h +++ b/src/vma/proto/route_table_mgr.h @@ -69,7 +69,7 @@ class route_table_mgr : public netlink_socket_mgr, public cache_table virtual void notify_cb(event *ev); protected: - virtual bool parse_enrty(nlmsghdr *nl_header, route_val *p_val); + virtual bool parse_entry(struct nl_object * nl_obj, void *p_val_context); private: // in constructor creates route_entry for each net_dev, to receive events in case there are no other route_entrys @@ -79,7 +79,7 @@ class route_table_mgr : public netlink_socket_mgr, public cache_table // save current main rt table void update_tbl(); - void parse_attr(struct rtattr *rt_attribute, route_val *p_val); + void parse_attr(struct rtnl_route *route, route_val *p_val); void rt_mgr_update_source_ip(); diff --git a/src/vma/proto/rule_table_mgr.cpp b/src/vma/proto/rule_table_mgr.cpp index 32847014f..8811a8f89 100644 --- a/src/vma/proto/rule_table_mgr.cpp +++ b/src/vma/proto/rule_table_mgr.cpp @@ -53,7 +53,6 @@ #include "rule_table_mgr.h" #include "vma/sock/socket_fd_api.h" #include "vma/sock/sock-redirect.h" -#include "ip_address.h" // debugging macros #define MODULE_NAME "rrm:" @@ -97,31 +96,32 @@ void rule_table_mgr::update_tbl() // nl_header : object that contain rule entry. // p_val : custom object that contain parsed rule data. // return true if its not related to local or default table, false otherwise. -bool rule_table_mgr::parse_enrty(nlmsghdr *nl_header, rule_val *p_val) +bool rule_table_mgr::parse_entry(struct nl_object * nl_obj, void *p_val_context) { - int len; - struct rtmsg *rt_msg; - struct rtattr *rt_attribute; + int err = 0; - // get rule entry header - rt_msg = (struct rtmsg *) NLMSG_DATA(nl_header); + rule_val* p_val = static_cast(p_val_context); + // Cast the generic nl_object to a specific route or rule object + struct rtnl_rule *rule = reinterpret_cast(nl_obj); - // we are not concerned about the local and default rule table - if (rt_msg->rtm_family != AF_INET || rt_msg->rtm_table == RT_TABLE_LOCAL) - return false; + uint32_t table_id = rtnl_rule_get_table(rule); + if (rtnl_rule_get_family(rule) != AF_INET || table_id == RT_TABLE_LOCAL) + return false; - p_val->set_protocol(rt_msg->rtm_protocol); - p_val->set_scope(rt_msg->rtm_scope); - p_val->set_type(rt_msg->rtm_type); - p_val->set_tos(rt_msg->rtm_tos); - p_val->set_table_id(rt_msg->rtm_table); + // Set rule properties in p_val using libnl getters + uint8_t protocol = 0; + err = rtnl_rule_get_protocol(rule, &protocol); + if (err < 0) + { + throw_vma_exception("Failed to get rule protocol"); + } - len = RTM_PAYLOAD(nl_header); - rt_attribute = (struct rtattr *) RTM_RTA(rt_msg); + p_val->set_protocol(protocol); + p_val->set_tos(rtnl_rule_get_dsfield(rule)); + p_val->set_table_id(table_id); + + parse_attr(rule, p_val); - for (;RTA_OK(rt_attribute, len);rt_attribute=RTA_NEXT(rt_attribute,len)) { - parse_attr(rt_attribute, p_val); - } p_val->set_state(true); p_val->set_str(); return true; @@ -131,33 +131,45 @@ bool rule_table_mgr::parse_enrty(nlmsghdr *nl_header, rule_val *p_val) // Parameters: // rt_attribute : object that contain rule attribute. // p_val : custom object that contain parsed rule data. -void rule_table_mgr::parse_attr(struct rtattr *rt_attribute, rule_val *p_val) +void rule_table_mgr::parse_attr(struct rtnl_rule *rule, rule_val *p_val) { - switch (rt_attribute->rta_type) { - case FRA_PRIORITY: - p_val->set_priority(*(uint32_t *)RTA_DATA(rt_attribute)); - break; - case FRA_DST: - p_val->set_dst_addr(*(in_addr_t *)RTA_DATA(rt_attribute)); - break; - case FRA_SRC: - p_val->set_src_addr(*(in_addr_t *)RTA_DATA(rt_attribute)); - break; - case FRA_IFNAME: - p_val->set_iif_name((char *)RTA_DATA(rt_attribute)); - break; - case FRA_TABLE: - p_val->set_table_id(*(uint32_t *)RTA_DATA(rt_attribute)); - break; + // FRA_PRIORITY: Rule Priority + uint32_t priority = rtnl_rule_get_prio(rule); + if (priority) { + p_val->set_priority(priority); + } + + // FRA_DST: Destination Address + struct nl_addr *dst = rtnl_rule_get_dst(rule); + if (dst) { + p_val->set_dst_addr(*(in_addr_t *)nl_addr_get_binary_addr(dst)); + } + + // FRA_SRC: Source Address + struct nl_addr *src = rtnl_rule_get_src(rule); + if (src) { + p_val->set_src_addr(*(in_addr_t *)nl_addr_get_binary_addr(src)); + } + + // FRA_IFNAME: Input Interface Name + char *iif_name = rtnl_rule_get_iif(rule); + if (iif_name) { + p_val->set_iif_name(iif_name); + } + + // FRA_TABLE: Table ID + uint32_t table_id = rtnl_rule_get_table(rule); + if (table_id) { + p_val->set_table_id(table_id); + } + #if DEFINED_FRA_OIFNAME - case FRA_OIFNAME: - p_val->set_oif_name((char *)RTA_DATA(rt_attribute)); - break; + // FRA_OIFNAME: Output Interface Name (if available) + char *oif_name = rtnl_rule_get_oif(rule); + if (oif_name) { + p_val->set_oif_name(oif_name); + } #endif - default: - rr_mgr_logdbg("got undetected rta_type %d %x", rt_attribute->rta_type, *(uint32_t *)RTA_DATA(rt_attribute)); - break; - } } diff --git a/src/vma/proto/rule_table_mgr.h b/src/vma/proto/rule_table_mgr.h index f3f6fdcf7..0713c8084 100644 --- a/src/vma/proto/rule_table_mgr.h +++ b/src/vma/proto/rule_table_mgr.h @@ -54,12 +54,12 @@ class rule_table_mgr : public netlink_socket_mgr, public cache_table_m bool rule_resolve(route_rule_table_key key, std::deque &table_id_list); protected: - virtual bool parse_enrty(nlmsghdr *nl_header, rule_val *p_val); + virtual bool parse_entry(struct nl_object * nl_obj, void *p_val_context); virtual void update_tbl(); private: - void parse_attr(struct rtattr *rt_attribute, rule_val *p_val); + void parse_attr(struct rtnl_rule *rule, rule_val *p_val); bool find_rule_val(route_rule_table_key key, std::deque* &p_val); bool is_matching_rule(route_rule_table_key rrk, rule_val* p_val); diff --git a/src/vma/proto/rule_val.cpp b/src/vma/proto/rule_val.cpp index 3c56f4561..dd09dd3c5 100644 --- a/src/vma/proto/rule_val.cpp +++ b/src/vma/proto/rule_val.cpp @@ -47,8 +47,6 @@ rule_val::rule_val(): cache_observer() { m_protocol = 0; - m_scope = 0; - m_type = 0; m_dst_addr = 0; m_src_addr = 0; memset(m_oif_name, 0, IFNAMSIZ * sizeof(char)); diff --git a/src/vma/proto/rule_val.h b/src/vma/proto/rule_val.h index 9d6f8d200..759fd9df2 100644 --- a/src/vma/proto/rule_val.h +++ b/src/vma/proto/rule_val.h @@ -53,8 +53,6 @@ class rule_val : public cache_observer inline void set_dst_addr(in_addr_t const &dst_addr) { m_dst_addr = dst_addr; }; inline void set_src_addr(in_addr_t const &src_addr) { m_src_addr = src_addr; }; inline void set_protocol(unsigned char protocol) { m_protocol = protocol; }; - inline void set_scope(unsigned char scope) { m_scope = scope; }; - inline void set_type(unsigned char type) { m_type = type; }; inline void set_tos(unsigned char tos) { m_tos = tos; }; inline void set_table_id(uint32_t table_id) { m_table_id = table_id; }; inline void set_iif_name(char *iif_name) { memcpy(m_iif_name, iif_name, IFNAMSIZ); }; @@ -79,8 +77,6 @@ class rule_val : public cache_observer private: unsigned char m_protocol; - unsigned char m_scope; - unsigned char m_type; unsigned char m_tos; union {