From a17a2d50247f02e2682de5afb8008f7dfa3e3326 Mon Sep 17 00:00:00 2001 From: kimden <23140380+kimden@users.noreply.github.com> Date: Mon, 1 Jul 2024 18:37:20 +0400 Subject: [PATCH 1/2] Separate SQL functions from server lobby I also attempted to reuse some code that checks if a peer is banned --- src/network/database_connector.cpp | 961 ++++++++++++++++++++ src/network/database_connector.hpp | 121 +++ src/network/protocols/server_lobby.cpp | 1139 ++++-------------------- src/network/protocols/server_lobby.hpp | 38 +- 4 files changed, 1240 insertions(+), 1019 deletions(-) create mode 100644 src/network/database_connector.cpp create mode 100644 src/network/database_connector.hpp diff --git a/src/network/database_connector.cpp b/src/network/database_connector.cpp new file mode 100644 index 00000000000..93844c897c7 --- /dev/null +++ b/src/network/database_connector.cpp @@ -0,0 +1,961 @@ +// +// SuperTuxKart - a fun racing game with go-kart +// Copyright (C) 2013-2015 SuperTuxKart-Team +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of the GNU General Public License +// as published by the Free Software Foundation; either version 3 +// of the License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + +#ifdef ENABLE_SQLITE3 + +#include "network/database_connector.hpp" + +#include "network/network_player_profile.hpp" +#include "network/server_config.hpp" +#include "network/socket_address.hpp" +#include "network/stk_host.hpp" +#include "network/stk_ipv6.hpp" +#include "network/stk_peer.hpp" +#include "utils/log.hpp" + +//----------------------------------------------------------------------------- +void DatabaseConnector::initDatabase() +{ + m_last_poll_db_time = StkTime::getMonoTimeMs(); + m_db = NULL; + m_ip_ban_table_exists = false; + m_ipv6_ban_table_exists = false; + m_online_id_ban_table_exists = false; + m_ip_geolocation_table_exists = false; + m_ipv6_geolocation_table_exists = false; + m_player_reports_table_exists = false; + if (!ServerConfig::m_sql_management) + return; + const std::string& path = ServerConfig::getConfigDirectory() + "/" + + ServerConfig::m_database_file.c_str(); + int ret = sqlite3_open_v2(path.c_str(), &m_db, + SQLITE_OPEN_SHAREDCACHE | SQLITE_OPEN_FULLMUTEX | + SQLITE_OPEN_READWRITE, NULL); + if (ret != SQLITE_OK) + { + Log::error("ServerLobby", "Cannot open database: %s.", + sqlite3_errmsg(m_db)); + sqlite3_close(m_db); + m_db = NULL; + return; + } + sqlite3_busy_handler(m_db, [](void* data, int retry) + { + int retry_count = ServerConfig::m_database_timeout / 100; + if (retry < retry_count) + { + sqlite3_sleep(100); + // Return non-zero to let caller retry again + return 1; + } + // Return zero to let caller return SQLITE_BUSY immediately + return 0; + }, NULL); + sqlite3_create_function(m_db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, + &insideIPv6CIDRSQL, NULL, NULL); + sqlite3_create_function(m_db, "upperIPv6", 1, SQLITE_UTF8, NULL, + &upperIPv6SQL, NULL, NULL); + checkTableExists(ServerConfig::m_ip_ban_table, m_ip_ban_table_exists); + checkTableExists(ServerConfig::m_ipv6_ban_table, m_ipv6_ban_table_exists); + checkTableExists(ServerConfig::m_online_id_ban_table, + m_online_id_ban_table_exists); + checkTableExists(ServerConfig::m_player_reports_table, + m_player_reports_table_exists); + checkTableExists(ServerConfig::m_ip_geolocation_table, + m_ip_geolocation_table_exists); + checkTableExists(ServerConfig::m_ipv6_geolocation_table, + m_ipv6_geolocation_table_exists); +} // initDatabase + +//----------------------------------------------------------------------------- +void DatabaseConnector::destroyDatabase() +{ + auto peers = STKHost::get()->getPeers(); + for (auto& peer : peers) + writeDisconnectInfoTable(peer.get()); + if (m_db != NULL) + sqlite3_close(m_db); +} // destroyDatabase + +//----------------------------------------------------------------------------- +/** Run simple query with write lock waiting and optional function, this + * function has no callback for the return (if any) by the query. + * Return true if no error occurs + */ +bool DatabaseConnector::easySQLQuery(const std::string& query, + std::function bind_function) const +{ + if (!m_db) + return false; + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + if (bind_function) + bind_function(stmt); + ret = sqlite3_step(stmt); + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for easy query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return false; + } + } + else + { + Log::error("DatabaseConnector", + "Error preparing database for easy query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return false; + } + return true; +} // easySQLQuery + +//----------------------------------------------------------------------------- +/* Write true to result if table name exists in database. */ +void DatabaseConnector::checkTableExists(const std::string& table, bool& result) +{ + if (!m_db) + return; + sqlite3_stmt* stmt = NULL; + if (!table.empty()) + { + std::string query = StringUtils::insertValues( + "SELECT count(type) FROM sqlite_master " + "WHERE type='table' AND name='%s';", table.c_str()); + + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + ret = sqlite3_step(stmt); + if (ret == SQLITE_ROW) + { + int number = sqlite3_column_int(stmt, 0); + if (number == 1) + { + Log::info("DatabaseConnector", "Table named %s will be used.", + table.c_str()); + result = true; + } + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + } + } + if (!result && !table.empty()) + { + Log::warn("DatabaseConnector", "Table named %s not found in database.", + table.c_str()); + } +} // checkTableExists + +//----------------------------------------------------------------------------- +std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const +{ + if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) + return ""; + + std::string cc_code; + std::string query = StringUtils::insertValues( + "SELECT country_code FROM %s " + "WHERE `ip_start` <= %d AND `ip_end` >= %d " + "ORDER BY `ip_start` DESC LIMIT 1;", + ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), + addr.getIP()); + + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + ret = sqlite3_step(stmt); + if (ret == SQLITE_ROW) + { + const char* country_code = (char*)sqlite3_column_text(stmt, 0); + cc_code = country_code; + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + } + else + { + Log::error("DatabaseConnector", "Error preparing database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return ""; + } + return cc_code; +} // ip2Country + +//----------------------------------------------------------------------------- +std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const +{ + if (!m_db || !m_ipv6_geolocation_table_exists) + return ""; + + std::string cc_code; + const std::string& ipv6 = addr.toString(false/*show_port*/); + std::string query = StringUtils::insertValues( + "SELECT country_code FROM %s " + "WHERE `ip_start` <= upperIPv6(\"%s\") AND `ip_end` >= upperIPv6(\"%s\") " + "ORDER BY `ip_start` DESC LIMIT 1;", + ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), + ipv6.c_str()); + + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + ret = sqlite3_step(stmt); + if (ret == SQLITE_ROW) + { + const char* country_code = (char*)sqlite3_column_text(stmt, 0); + cc_code = country_code; + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + } + else + { + Log::error("DatabaseConnector", "Error preparing database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return ""; + } + return cc_code; +} // ipv62Country + +// ---------------------------------------------------------------------------- +void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, + sqlite3_value** argv) +{ + if (argc != 1) + { + sqlite3_result_int64(context, 0); + return; + } + + char* ipv6 = (char*)sqlite3_value_text(argv[0]); + if (ipv6 == NULL) + { + sqlite3_result_int64(context, 0); + return; + } + sqlite3_result_int64(context, upperIPv6(ipv6)); +} + +// ---------------------------------------------------------------------------- +void DatabaseConnector::insideIPv6CIDRSQL(sqlite3_context* context, int argc, + sqlite3_value** argv) +{ + if (argc != 2) + { + sqlite3_result_int(context, 0); + return; + } + + char* ipv6_cidr = (char*)sqlite3_value_text(argv[0]); + char* ipv6_in = (char*)sqlite3_value_text(argv[1]); + if (ipv6_cidr == NULL || ipv6_in == NULL) + { + sqlite3_result_int(context, 0); + return; + } + sqlite3_result_int(context, insideIPv6CIDR(ipv6_cidr, ipv6_in)); +} // insideIPv6CIDRSQL + +// ---------------------------------------------------------------------------- +/* +Copy below code so it can be use as loadable extension to be used in sqlite3 +command interface (together with andIPv6 and insideIPv6CIDR from stk_ipv6) + +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 +// ---------------------------------------------------------------------------- +sqlite3_extension_init(sqlite3* db, char** pzErrMsg, + const sqlite3_api_routines* pApi) +{ + SQLITE_EXTENSION_INIT2(pApi) + sqlite3_create_function(db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, + insideIPv6CIDRSQL, NULL, NULL); + sqlite3_create_function(db, "upperIPv6", 1, SQLITE_UTF8, 0, upperIPv6SQL, + 0, 0); + return 0; +} // sqlite3_extension_init +*/ + +//----------------------------------------------------------------------------- +void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) +{ + if (m_server_stats_table.empty()) + return; + std::string query = StringUtils::insertValues( + "UPDATE %s SET disconnected_time = datetime('now'), " + "ping = %d, packet_loss = %d " + "WHERE host_id = %u;", m_server_stats_table.c_str(), + peer->getAveragePing(), peer->getPacketLoss(), + peer->getHostId()); + easySQLQuery(query); +} // writeDisconnectInfoTable + +//----------------------------------------------------------------------------- + +void DatabaseConnector::initServerStatsTable() +{ + if (!ServerConfig::m_sql_management || !m_db) + return; + std::string table_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_stats"; + + std::ostringstream oss; + oss << "CREATE TABLE IF NOT EXISTS " << table_name << " (\n" + " host_id INTEGER UNSIGNED NOT NULL PRIMARY KEY, -- Unique host id in STKHost of each connection session for a STKPeer\n" + " ip INTEGER UNSIGNED NOT NULL, -- IP decimal of host\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6 TEXT NOT NULL DEFAULT '', -- IPv6 (if exists) in string of host\n"; + oss << " port INTEGER UNSIGNED NOT NULL, -- Port of host\n" + " online_id INTEGER UNSIGNED NOT NULL, -- Online if of the host (0 for offline account)\n" + " username TEXT NOT NULL, -- First player name in the host (if the host has splitscreen player)\n" + " player_num INTEGER UNSIGNED NOT NULL, -- Number of player(s) from the host, more than 1 if it has splitscreen player\n" + " country_code TEXT NULL DEFAULT NULL, -- 2-letter country code of the host\n" + " version TEXT NOT NULL, -- SuperTuxKart version of the host\n" + " os TEXT NOT NULL, -- Operating system of the host\n" + " connected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when connected\n" + " disconnected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when disconnected (saved when disconnected)\n" + " ping INTEGER UNSIGNED NOT NULL DEFAULT 0, -- Ping of the host\n" + " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" + ") WITHOUT ROWID;"; + std::string query = oss.str(); + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + ret = sqlite3_step(stmt); + ret = sqlite3_finalize(stmt); + if (ret == SQLITE_OK) + m_server_stats_table = table_name; + else + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + } + else + { + Log::error("DatabaseConnector", "Error preparing database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + if (m_server_stats_table.empty()) + return; + + // Extra default table _countries: + // Server owner need to initialise this table himself, check NETWORKING.md + std::string country_table_name = std::string("v") + StringUtils::toString( + ServerConfig::m_server_db_version) + "_countries"; + query = StringUtils::insertValues( + "CREATE TABLE IF NOT EXISTS %s (\n" + " country_code TEXT NOT NULL PRIMARY KEY UNIQUE, -- Unique 2-letter country code\n" + " country_flag TEXT NOT NULL, -- Unicode country flag representation of 2-letter country code\n" + " country_name TEXT NOT NULL -- Readable name of this country\n" + ") WITHOUT ROWID;", country_table_name.c_str()); + easySQLQuery(query); + + // Default views: + // _full_stats + // Full stats with ip in human readable format and time played of each + // players in minutes + std::string full_stats_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_full_stats"; + oss.str(""); + oss << "CREATE VIEW IF NOT EXISTS " << full_stats_view_name << " AS\n" + << " SELECT host_id, ip,\n" + << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6,"; + oss << " port, online_id, username, player_num,\n" + << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" + << " ROUND((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" + << " connected_time, disconnected_time, ping, packet_loss FROM " << m_server_stats_table << "\n" + << " LEFT JOIN " << country_table_name << " ON " + << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" + << " ORDER BY connected_time DESC;"; + query = oss.str(); + easySQLQuery(query); + + // _current_players + // Current players in server with ip in human readable format and time + // played of each players in minutes + std::string current_players_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_current_players"; + oss.str(""); + oss.clear(); + oss << "CREATE VIEW IF NOT EXISTS " << current_players_view_name << " AS\n" + << " SELECT host_id, ip,\n" + << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6,"; + oss << " port, online_id, username, player_num,\n" + << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" + << " ROUND((STRFTIME(\"%s\", 'now') - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" + << " connected_time, ping FROM " << m_server_stats_table << "\n" + << " LEFT JOIN " << country_table_name << " ON " + << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" + << " WHERE connected_time = disconnected_time;"; + query = oss.str(); + easySQLQuery(query); + + // _player_stats + // All players with online id and username with their time played stats + // in this server since creation of this database + // If sqlite supports window functions (since 3.25), it will include last session player info (ip, country, ping...) + std::string player_stats_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_player_stats"; + oss.str(""); + oss.clear(); + if (sqlite3_libversion_number() < 3025000) + { + oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" + << " SELECT online_id, username, COUNT(online_id) AS num_connections,\n" + << " MIN(connected_time) AS first_connected_time,\n" + << " MAX(connected_time) AS last_connected_time,\n" + << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" + << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" + << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" + << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" + << " FROM " << m_server_stats_table << "\n" + << " WHERE online_id != 0 GROUP BY online_id ORDER BY num_connections DESC;"; + } + else + { + oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" + << " SELECT a.online_id, a.username, a.ip, a.ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " a.ipv6,"; + oss << " a.port, a.player_num,\n" + << " a.country_code, a.country_flag, a.country_name, a.version, a.os, a.ping, a.packet_loss,\n" + << " b.num_connections, b.first_connected_time, b.first_disconnected_time,\n" + << " a.connected_time AS last_connected_time, a.disconnected_time AS last_disconnected_time,\n" + << " a.time_played AS last_time_played, b.total_time_played, b.average_time_played,\n" + << " b.min_time_played, b.max_time_played\n" + << " FROM\n" + << " (\n" + << " SELECT *,\n" + << " ROW_NUMBER() OVER\n" + << " (\n" + << " PARTITION BY online_id\n" + << " ORDER BY connected_time DESC\n" + << " ) RowNum\n" + << " FROM " << full_stats_view_name << " where online_id != 0\n" + << " ) as a\n" + << " JOIN\n" + << " (\n" + << " SELECT online_id, COUNT(online_id) AS num_connections,\n" + << " MIN(connected_time) AS first_connected_time,\n" + << " MIN(disconnected_time) AS first_disconnected_time,\n" + << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" + << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" + << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" + << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" + << " FROM " << m_server_stats_table << " WHERE online_id != 0 GROUP BY online_id\n" + << " ) AS b\n" + << " ON b.online_id = a.online_id\n" + << " WHERE RowNum = 1 ORDER BY num_connections DESC;\n"; + } + query = oss.str(); + easySQLQuery(query); + + uint32_t last_host_id = 0; + query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", + m_server_stats_table.c_str()); + ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + ret = sqlite3_step(stmt); + if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL) + { + last_host_id = (unsigned)sqlite3_column_int64(stmt, 0); + Log::info("DatabaseConnector", "%u was last server session max host id.", + last_host_id); + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + m_server_stats_table = ""; + } + } + else + { + Log::error("DatabaseConnector", "Error preparing database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + m_server_stats_table = ""; + } + STKHost::get()->setNextHostId(last_host_id); + + // Update disconnected time (if stk crashed it will not be written) + query = StringUtils::insertValues( + "UPDATE %s SET disconnected_time = datetime('now') " + "WHERE connected_time = disconnected_time;", + m_server_stats_table.c_str()); + easySQLQuery(query); +} // initServerStatsTable + +//----------------------------------------------------------------------------- +bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info) +{ + std::string query; + if (ServerConfig::m_ipv6_connection) + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " + "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " + "VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);", + ServerConfig::m_player_reports_table.c_str(), + !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, + reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", + reporter_npp->getOnlineId(), + !reporting->getAddress().isIPv6() ? reporting->getAddress().getIP() : 0, + reporting->getAddress().isIPv6() ? reporting->getAddress().toString(false) : "", + reporting_npp->getOnlineId()); + } + else + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(server_uid, reporter_ip, reporter_online_id, reporter_username, " + "info, reporting_ip, reporting_online_id, reporting_username) " + "VALUES (?, %u, %u, ?, ?, %u, %u, ?);", + ServerConfig::m_player_reports_table.c_str(), + reporter->getAddress().getIP(), reporter_npp->getOnlineId(), + reporting->getAddress().getIP(), reporting_npp->getOnlineId()); + } + return easySQLQuery(query, + [reporter_npp, reporting_npp, info](sqlite3_stmt* stmt) + { + // SQLITE_TRANSIENT to copy string + if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + ServerConfig::m_server_uid.c_str()); + } + if (sqlite3_bind_text(stmt, 2, + StringUtils::wideToUtf8(reporter_npp->getName()).c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + StringUtils::wideToUtf8(reporter_npp->getName()).c_str()); + } + if (sqlite3_bind_text(stmt, 3, + StringUtils::wideToUtf8(info).c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + StringUtils::wideToUtf8(info).c_str()); + } + if (sqlite3_bind_text(stmt, 4, + StringUtils::wideToUtf8(reporting_npp->getName()).c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + StringUtils::wideToUtf8(reporting_npp->getName()).c_str()); + } + }); +} // writeReport + +//----------------------------------------------------------------------------- +std::vector +DatabaseConnector::getIpBanTableData(uint32_t ip) const +{ + std::vector result; + if (!m_ip_ban_table_exists) + { + return result; + } + bool single_ip = (ip != 0); + std::ostringstream oss; + oss << "SELECT rowid, ip_start, ip_end, reason, description FROM "; + oss << (std::string)ServerConfig::m_ip_ban_table << " WHERE "; + if (single_ip) + oss << "ip_start <= " << ip << " AND ip_end >= " << ip << " AND "; + oss << "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_ip) + oss << " LIMIT 1"; + oss << ";"; + std::string query = oss.str(); + sqlite3_exec(m_db, query.c_str(), + [](void* ptr, int count, char** data, char** columns) + { + std::vector* vec = (std::vector*)ptr; + IpBanTableData element; + if (!StringUtils::fromString(data[0], element.row_id)) + return 0; + if (!StringUtils::fromString(data[1], element.ip_start)) + return 0; + if (!StringUtils::fromString(data[2], element.ip_end)) + return 0; + element.reason = std::string(data[3]); + element.description = std::string(data[4]); + vec->push_back(element); + return 0; + }, &result, NULL); + return result; +} // getIpBanTableData + +//----------------------------------------------------------------------------- +void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const +{ + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE ip_start = %u AND ip_end = %u;", + ServerConfig::m_ip_ban_table.c_str(), ip_start, ip_end); + easySQLQuery(query); +} // getIpBanTableData + +//----------------------------------------------------------------------------- +std::vector +DatabaseConnector::getIpv6BanTableData(std::string ipv6) const +{ + std::vector result; + if (!m_ipv6_ban_table_exists) + { + return result; + } + bool single_ip = !ipv6.empty(); + std::ostringstream oss; + oss << "SELECT rowid, ipv6_cidr, reason, description FROM "; + oss << (std::string)ServerConfig::m_ipv6_ban_table; + oss << " WHERE "; + if (single_ip) + oss << "insideIPv6CIDR(ipv6_cidr, ?) = 1 AND "; + oss << "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_ip) + oss << " LIMIT 1"; + oss << ";"; + std::string query = oss.str(); + + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + if (single_ip) + { + if (sqlite3_bind_text(stmt, 1, + ipv6.c_str(), -1, SQLITE_TRANSIENT) + != SQLITE_OK) + { + Log::error("DatabaseConnector", "Error binding ipv6 addr for query: %s", + sqlite3_errmsg(m_db)); + return result; + } + } + ret = sqlite3_step(stmt); + while (ret == SQLITE_ROW) + { + const char* rowid_cstr = (char*)sqlite3_column_text(stmt, 0); + const char* ipv6cidr_cstr = (char*)sqlite3_column_text(stmt, 1); + const char* reason_cstr = (char*)sqlite3_column_text(stmt, 2); + const char* description_cstr = (char*)sqlite3_column_text(stmt, 3); + Ipv6BanTableData element; + if (StringUtils::fromString(rowid_cstr, element.row_id)) + { + element.ipv6_cidr = std::string(ipv6cidr_cstr); + element.reason = std::string(reason_cstr); + element.description = std::string(description_cstr); + result.push_back(element); + } + ret = sqlite3_step(stmt); + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + } + } + else + { + Log::error("DatabaseConnector", "Error preparing database for query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return result; + } + return result; +} // getIpv6BanTableData + +//----------------------------------------------------------------------------- +void DatabaseConnector::increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const +{ + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str()); + easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt) + { + if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + ipv6_cidr.c_str()); + } + }); +} // increaseIpv6BanTriggerCount + +//----------------------------------------------------------------------------- +std::vector +DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const +{ + std::vector result; + if (!m_online_id_ban_table_exists) + { + return result; + } + bool single_id = (online_id != 0); + std::ostringstream oss; + oss << "SELECT rowid, online_id, reason, description FROM "; + oss << (std::string)ServerConfig::m_online_id_ban_table; + oss << " WHERE "; + if (single_id) + oss << "online_id = " << online_id << " AND "; + oss << "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_id) + oss << " LIMIT 1"; + oss << ";"; + std::string query = oss.str(); + sqlite3_exec(m_db, query.c_str(), + [](void* ptr, int count, char** data, char** columns) + { + std::vector* vec = (std::vector*)ptr; + OnlineIdBanTableData element; + if (!StringUtils::fromString(data[0], element.row_id)) + return 0; + if (!StringUtils::fromString(data[1], element.online_id)) + return 0; + element.reason = std::string(data[2]); + element.description = std::string(data[3]); + vec->push_back(element); + return 0; + }, &result, NULL); + return result; +} // getOnlineIdBanTableData + +//----------------------------------------------------------------------------- +void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) const +{ + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE online_id = %u;", + ServerConfig::m_online_id_ban_table.c_str(), online_id); + easySQLQuery(query); +} // increaseOnlineIdBanTriggerCount + +//----------------------------------------------------------------------------- +void DatabaseConnector::clearOldReports() +{ + if (m_player_reports_table_exists && + ServerConfig::m_player_reports_expired_days != 0.0f) + { + std::string query = StringUtils::insertValues( + "DELETE FROM %s " + "WHERE datetime" + "(reported_time, '+%f days') < datetime('now');", + ServerConfig::m_player_reports_table.c_str(), + ServerConfig::m_player_reports_expired_days); + easySQLQuery(query); + } +} // clearOldReports + +//----------------------------------------------------------------------------- +void DatabaseConnector::setDisconnectionTimes(std::vector& present_hosts) +{ + if (!hasServerStatsTable()) + return; + std::ostringstream oss; + oss << "UPDATE " << m_server_stats_table + << " SET disconnected_time = datetime('now')" + << " WHERE connected_time = disconnected_time"; + if (present_hosts.empty()) + { + oss << ";"; + } + else + { + oss << " AND host_id NOT IN ("; + for (unsigned i = 0; i < present_hosts.size(); i++) + { + if (i > 0) + oss << ","; + oss << present_hosts[i]; + } + oss << ");"; + } + std::string query = oss.str(); + easySQLQuery(query); +} // setDisconnectionTimes + +//----------------------------------------------------------------------------- +void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) +{ + if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) + return; + + std::string query = StringUtils::insertValues( + "INSERT INTO %s (ip_start, ip_end) " + "VALUES (%u, %u);", + ServerConfig::m_ip_ban_table.c_str(), addr.getIP(), addr.getIP()); + easySQLQuery(query); +} // saveAddressToIpBanTable + +//----------------------------------------------------------------------------- +void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, + uint32_t online_id, unsigned player_count, const std::string& country_code) +{ + if (m_server_stats_table.empty() || peer->isAIPeer()) + return; + std::string query; + if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(host_id, ip, ipv6 ,port, online_id, username, player_num, " + "country_code, version, os, ping) " + "VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);", + m_server_stats_table.c_str(), peer->getHostId(), + peer->getAddress().toString(false), peer->getAddress().getPort(), + online_id, player_count, peer->getAveragePing()); + } + else + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(host_id, ip, port, online_id, username, player_num, " + "country_code, version, os, ping) " + "VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);", + m_server_stats_table.c_str(), peer->getHostId(), + peer->getAddress().getIP(), peer->getAddress().getPort(), + online_id, player_count, peer->getAveragePing()); + } + easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt) + { + if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8( + peer->getPlayerProfiles()[0]->getName()).c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + StringUtils::wideToUtf8( + peer->getPlayerProfiles()[0]->getName()).c_str()); + } + if (country_code.empty()) + { + if (sqlite3_bind_null(stmt, 2) != SQLITE_OK) + { + Log::error("easySQLQuery", + "Failed to bind NULL for country code."); + } + } + else + { + if (sqlite3_bind_text(stmt, 2, country_code.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind country: %s.", + country_code.c_str()); + } + } + auto version_os = + StringUtils::extractVersionOS(peer->getUserVersion()); + if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + version_os.first.c_str()); + } + if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s.", + version_os.second.c_str()); + } + }); +} // onPlayerJoinQueries + +//----------------------------------------------------------------------------- +void DatabaseConnector::listBanTable() +{ + if (!m_db) + return; + auto printer = [](void* data, int argc, char** argv, char** name) + { + for (int i = 0; i < argc; i++) + { + std::cout << name[i] << " = " << (argv[i] ? argv[i] : "NULL") + << "\n"; + } + std::cout << "\n"; + return 0; + }; + if (m_ip_ban_table_exists) + { + std::string query = "SELECT * FROM "; + query += ServerConfig::m_ip_ban_table; + query += ";"; + std::cout << "IP ban list:\n"; + sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); + } + if (m_online_id_ban_table_exists) + { + std::string query = "SELECT * FROM "; + query += ServerConfig::m_online_id_ban_table; + query += ";"; + std::cout << "Online Id ban list:\n"; + sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); + } +} // listBanTable +#endif // ENABLE_SQLITE3 diff --git a/src/network/database_connector.hpp b/src/network/database_connector.hpp new file mode 100644 index 00000000000..282fc859863 --- /dev/null +++ b/src/network/database_connector.hpp @@ -0,0 +1,121 @@ +// +// SuperTuxKart - a fun racing game with go-kart +// Copyright (C) 2024 SuperTuxKart-Team +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of the GNU General Public License +// as published by the Free Software Foundation; either version 3 +// of the License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + +#ifdef ENABLE_SQLITE3 + +#ifndef DATABASE_CONNECTOR_HPP +#define DATABASE_CONNECTOR_HPP + +#include "utils/string_utils.hpp" +#include "utils/time.hpp" + +#include +#include +#include +#include +#include +#include + +class SocketAddress; +class STKPeer; +class NetworkPlayerProfile; + + +class DatabaseConnector +{ +private: + sqlite3* m_db; + std::string m_server_stats_table; + bool m_ip_ban_table_exists; + bool m_ipv6_ban_table_exists; + bool m_online_id_ban_table_exists; + bool m_ip_geolocation_table_exists; + bool m_ipv6_geolocation_table_exists; + bool m_player_reports_table_exists; + uint64_t m_last_poll_db_time; + +public: + struct IpBanTableData + { + int row_id; + uint32_t ip_start; + uint32_t ip_end; + std::string reason; + std::string description; + }; + struct Ipv6BanTableData { + int row_id; + std::string ipv6_cidr; + std::string reason; + std::string description; + }; + struct OnlineIdBanTableData { + int row_id; + uint32_t online_id; + std::string reason; + std::string description; + }; + void initDatabase(); + void destroyDatabase(); + + bool easySQLQuery(const std::string& query, + std::function bind_function = nullptr) const; + + void checkTableExists(const std::string& table, bool& result); + + std::string ip2Country(const SocketAddress& addr) const; + + std::string ipv62Country(const SocketAddress& addr) const; + + static void upperIPv6SQL(sqlite3_context* context, int argc, + sqlite3_value** argv); + static void insideIPv6CIDRSQL(sqlite3_context* context, int argc, + sqlite3_value** argv); + void writeDisconnectInfoTable(STKPeer* peer); + void initServerStatsTable(); + bool writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info); + bool hasDatabase() const { return m_db != nullptr; } + bool hasServerStatsTable() const { return !m_server_stats_table.empty(); } + bool hasPlayerReportsTable() const + { return m_player_reports_table_exists; } + bool hasIpBanTable() const { return m_ip_ban_table_exists; } + bool hasIpv6BanTable() const { return m_ipv6_ban_table_exists; } + bool hasOnlineIdBanTable() const { return m_online_id_ban_table_exists; } + bool isTimeToPoll() const + { return StkTime::getMonoTimeMs() >= m_last_poll_db_time + 60000; } + void updatePollTime() { m_last_poll_db_time = StkTime::getMonoTimeMs(); } + std::vector getIpBanTableData(uint32_t ip = 0) const; + std::vector getIpv6BanTableData(std::string ipv6 = "") const; + std::vector getOnlineIdBanTableData(uint32_t online_id = 0) const; + void increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const; + void increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const; + void increaseOnlineIdBanTriggerCount(uint32_t online_id) const; + void clearOldReports(); + void setDisconnectionTimes(std::vector& present_hosts); + void saveAddressToIpBanTable(const SocketAddress& addr); + void onPlayerJoinQueries(std::shared_ptr peer, uint32_t online_id, + unsigned player_count, const std::string& country_code); + void listBanTable(); +}; + + + +#endif // ifndef DATABASE_CONNECTOR_HPP +#endif // ifdef ENABLE_SQLITE3 diff --git a/src/network/protocols/server_lobby.cpp b/src/network/protocols/server_lobby.cpp index 70283270d3d..12092553663 100644 --- a/src/network/protocols/server_lobby.cpp +++ b/src/network/protocols/server_lobby.cpp @@ -30,6 +30,7 @@ #include "modes/capture_the_flag.hpp" #include "modes/linear_world.hpp" #include "network/crypto.hpp" +#include "network/database_connector.hpp" #include "network/event.hpp" #include "network/game_setup.hpp" #include "network/network.hpp" @@ -102,69 +103,6 @@ class SubmitRankingRequest : public Online::XMLRequest // We use max priority for all server requests to avoid downloading of addons // icons blocking the poll request in all-in-one graphical client server -#ifdef ENABLE_SQLITE3 - -// ---------------------------------------------------------------------------- -static void upperIPv6SQL(sqlite3_context* context, int argc, - sqlite3_value** argv) -{ - if (argc != 1) - { - sqlite3_result_int64(context, 0); - return; - } - - char* ipv6 = (char*)sqlite3_value_text(argv[0]); - if (ipv6 == NULL) - { - sqlite3_result_int64(context, 0); - return; - } - sqlite3_result_int64(context, upperIPv6(ipv6)); -} - -// ---------------------------------------------------------------------------- -void insideIPv6CIDRSQL(sqlite3_context* context, int argc, - sqlite3_value** argv) -{ - if (argc != 2) - { - sqlite3_result_int(context, 0); - return; - } - - char* ipv6_cidr = (char*)sqlite3_value_text(argv[0]); - char* ipv6_in = (char*)sqlite3_value_text(argv[1]); - if (ipv6_cidr == NULL || ipv6_in == NULL) - { - sqlite3_result_int(context, 0); - return; - } - sqlite3_result_int(context, insideIPv6CIDR(ipv6_cidr, ipv6_in)); -} // insideIPv6CIDRSQL - -// ---------------------------------------------------------------------------- -/* -Copy below code so it can be use as loadable extension to be used in sqlite3 -command interface (together with andIPv6 and insideIPv6CIDR from stk_ipv6) - -#include "sqlite3ext.h" -SQLITE_EXTENSION_INIT1 -// ---------------------------------------------------------------------------- -sqlite3_extension_init(sqlite3* db, char** pzErrMsg, - const sqlite3_api_routines* pApi) -{ - SQLITE_EXTENSION_INIT2(pApi) - sqlite3_create_function(db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, - insideIPv6CIDRSQL, NULL, NULL); - sqlite3_create_function(db, "upperIPv6", 1, SQLITE_UTF8, 0, upperIPv6SQL, - 0, 0); - return 0; -} // sqlite3_extension_init -*/ - -#endif - /** This is the central game setup protocol running in the server. It is * mostly a finite state machine. Note that all nodes in ellipses and light * grey background are actual states; nodes in boxes and white background @@ -248,8 +186,11 @@ ServerLobby::ServerLobby() : LobbyProtocol() m_difficulty.store(ServerConfig::m_server_difficulty); m_game_mode.store(ServerConfig::m_server_mode); m_default_vote = new PeerVote(); - m_player_reports_table_exists = false; - initDatabase(); + +#ifdef ENABLE_SQLITE3 + m_db_connector = new DatabaseConnector(); + m_db_connector->initDatabase(); +#endif } // ServerLobby //----------------------------------------------------------------------------- @@ -267,302 +208,22 @@ ServerLobby::~ServerLobby() if (m_save_server_config) ServerConfig::writeServerConfigToDisk(); delete m_default_vote; - destroyDatabase(); -} // ~ServerLobby -//----------------------------------------------------------------------------- -void ServerLobby::initDatabase() -{ #ifdef ENABLE_SQLITE3 - m_last_poll_db_time = StkTime::getMonoTimeMs(); - m_db = NULL; - m_ip_ban_table_exists = false; - m_ipv6_ban_table_exists = false; - m_online_id_ban_table_exists = false; - m_ip_geolocation_table_exists = false; - m_ipv6_geolocation_table_exists = false; - if (!ServerConfig::m_sql_management) - return; - const std::string& path = ServerConfig::getConfigDirectory() + "/" + - ServerConfig::m_database_file.c_str(); - int ret = sqlite3_open_v2(path.c_str(), &m_db, - SQLITE_OPEN_SHAREDCACHE | SQLITE_OPEN_FULLMUTEX | - SQLITE_OPEN_READWRITE, NULL); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", "Cannot open database: %s.", - sqlite3_errmsg(m_db)); - sqlite3_close(m_db); - m_db = NULL; - return; - } - sqlite3_busy_handler(m_db, [](void* data, int retry) - { - int retry_count = ServerConfig::m_database_timeout / 100; - if (retry < retry_count) - { - sqlite3_sleep(100); - // Return non-zero to let caller retry again - return 1; - } - // Return zero to let caller return SQLITE_BUSY immediately - return 0; - }, NULL); - sqlite3_create_function(m_db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, - &insideIPv6CIDRSQL, NULL, NULL); - sqlite3_create_function(m_db, "upperIPv6", 1, SQLITE_UTF8, NULL, - &upperIPv6SQL, NULL, NULL); - checkTableExists(ServerConfig::m_ip_ban_table, m_ip_ban_table_exists); - checkTableExists(ServerConfig::m_ipv6_ban_table, m_ipv6_ban_table_exists); - checkTableExists(ServerConfig::m_online_id_ban_table, - m_online_id_ban_table_exists); - checkTableExists(ServerConfig::m_player_reports_table, - m_player_reports_table_exists); - checkTableExists(ServerConfig::m_ip_geolocation_table, - m_ip_geolocation_table_exists); - checkTableExists(ServerConfig::m_ipv6_geolocation_table, - m_ipv6_geolocation_table_exists); + m_db_connector->destroyDatabase(); + delete m_db_connector; #endif -} // initDatabase +} // ~ServerLobby //----------------------------------------------------------------------------- + void ServerLobby::initServerStatsTable() { #ifdef ENABLE_SQLITE3 - if (!ServerConfig::m_sql_management || !m_db) - return; - std::string table_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_stats"; - - std::ostringstream oss; - oss << "CREATE TABLE IF NOT EXISTS " << table_name << " (\n" - " host_id INTEGER UNSIGNED NOT NULL PRIMARY KEY, -- Unique host id in STKHost of each connection session for a STKPeer\n" - " ip INTEGER UNSIGNED NOT NULL, -- IP decimal of host\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6 TEXT NOT NULL DEFAULT '', -- IPv6 (if exists) in string of host\n"; - oss << " port INTEGER UNSIGNED NOT NULL, -- Port of host\n" - " online_id INTEGER UNSIGNED NOT NULL, -- Online if of the host (0 for offline account)\n" - " username TEXT NOT NULL, -- First player name in the host (if the host has splitscreen player)\n" - " player_num INTEGER UNSIGNED NOT NULL, -- Number of player(s) from the host, more than 1 if it has splitscreen player\n" - " country_code TEXT NULL DEFAULT NULL, -- 2-letter country code of the host\n" - " version TEXT NOT NULL, -- SuperTuxKart version of the host\n" - " os TEXT NOT NULL, -- Operating system of the host\n" - " connected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when connected\n" - " disconnected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when disconnected (saved when disconnected)\n" - " ping INTEGER UNSIGNED NOT NULL DEFAULT 0, -- Ping of the host\n" - " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" - ") WITHOUT ROWID;"; - std::string query = oss.str(); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret == SQLITE_OK) - m_server_stats_table = table_name; - else - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - if (m_server_stats_table.empty()) - return; - - // Extra default table _countries: - // Server owner need to initialise this table himself, check NETWORKING.md - std::string country_table_name = std::string("v") + StringUtils::toString( - ServerConfig::m_server_db_version) + "_countries"; - query = StringUtils::insertValues( - "CREATE TABLE IF NOT EXISTS %s (\n" - " country_code TEXT NOT NULL PRIMARY KEY UNIQUE, -- Unique 2-letter country code\n" - " country_flag TEXT NOT NULL, -- Unicode country flag representation of 2-letter country code\n" - " country_name TEXT NOT NULL -- Readable name of this country\n" - ") WITHOUT ROWID;", country_table_name.c_str()); - easySQLQuery(query); - - // Default views: - // _full_stats - // Full stats with ip in human readable format and time played of each - // players in minutes - std::string full_stats_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_full_stats"; - oss.str(""); - oss << "CREATE VIEW IF NOT EXISTS " << full_stats_view_name << " AS\n" - << " SELECT host_id, ip,\n" - << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6,"; - oss << " port, online_id, username, player_num,\n" - << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" - << " ROUND((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" - << " connected_time, disconnected_time, ping, packet_loss FROM " << m_server_stats_table << "\n" - << " LEFT JOIN " << country_table_name << " ON " - << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" - << " ORDER BY connected_time DESC;"; - query = oss.str(); - easySQLQuery(query); - - // _current_players - // Current players in server with ip in human readable format and time - // played of each players in minutes - std::string current_players_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_current_players"; - oss.str(""); - oss.clear(); - oss << "CREATE VIEW IF NOT EXISTS " << current_players_view_name << " AS\n" - << " SELECT host_id, ip,\n" - << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6,"; - oss << " port, online_id, username, player_num,\n" - << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" - << " ROUND((STRFTIME(\"%s\", 'now') - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" - << " connected_time, ping FROM " << m_server_stats_table << "\n" - << " LEFT JOIN " << country_table_name << " ON " - << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" - << " WHERE connected_time = disconnected_time;"; - query = oss.str(); - easySQLQuery(query); - - // _player_stats - // All players with online id and username with their time played stats - // in this server since creation of this database - // If sqlite supports window functions (since 3.25), it will include last session player info (ip, country, ping...) - std::string player_stats_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_player_stats"; - oss.str(""); - oss.clear(); - if (sqlite3_libversion_number() < 3025000) - { - oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" - << " SELECT online_id, username, COUNT(online_id) AS num_connections,\n" - << " MIN(connected_time) AS first_connected_time,\n" - << " MAX(connected_time) AS last_connected_time,\n" - << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" - << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" - << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" - << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" - << " FROM " << m_server_stats_table << "\n" - << " WHERE online_id != 0 GROUP BY online_id ORDER BY num_connections DESC;"; - } - else - { - oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" - << " SELECT a.online_id, a.username, a.ip, a.ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " a.ipv6,"; - oss << " a.port, a.player_num,\n" - << " a.country_code, a.country_flag, a.country_name, a.version, a.os, a.ping, a.packet_loss,\n" - << " b.num_connections, b.first_connected_time, b.first_disconnected_time,\n" - << " a.connected_time AS last_connected_time, a.disconnected_time AS last_disconnected_time,\n" - << " a.time_played AS last_time_played, b.total_time_played, b.average_time_played,\n" - << " b.min_time_played, b.max_time_played\n" - << " FROM\n" - << " (\n" - << " SELECT *,\n" - << " ROW_NUMBER() OVER\n" - << " (\n" - << " PARTITION BY online_id\n" - << " ORDER BY connected_time DESC\n" - << " ) RowNum\n" - << " FROM " << full_stats_view_name << " where online_id != 0\n" - << " ) as a\n" - << " JOIN\n" - << " (\n" - << " SELECT online_id, COUNT(online_id) AS num_connections,\n" - << " MIN(connected_time) AS first_connected_time,\n" - << " MIN(disconnected_time) AS first_disconnected_time,\n" - << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" - << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" - << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" - << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" - << " FROM " << m_server_stats_table << " WHERE online_id != 0 GROUP BY online_id\n" - << " ) AS b\n" - << " ON b.online_id = a.online_id\n" - << " WHERE RowNum = 1 ORDER BY num_connections DESC;\n"; - } - query = oss.str(); - easySQLQuery(query); - - uint32_t last_host_id = 0; - query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", - m_server_stats_table.c_str()); - ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL) - { - last_host_id = (unsigned)sqlite3_column_int64(stmt, 0); - Log::info("ServerLobby", "%u was last server session max host id.", - last_host_id); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } - STKHost::get()->setNextHostId(last_host_id); - - // Update disconnected time (if stk crashed it will not be written) - query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now') " - "WHERE connected_time = disconnected_time;", - m_server_stats_table.c_str()); - easySQLQuery(query); + m_db_connector->initServerStatsTable(); #endif } // initServerStatsTable -//----------------------------------------------------------------------------- -void ServerLobby::destroyDatabase() -{ -#ifdef ENABLE_SQLITE3 - auto peers = STKHost::get()->getPeers(); - for (auto& peer : peers) - writeDisconnectInfoTable(peer.get()); - if (m_db != NULL) - sqlite3_close(m_db); -#endif -} // destroyDatabase - -//----------------------------------------------------------------------------- -void ServerLobby::writeDisconnectInfoTable(STKPeer* peer) -{ -#ifdef ENABLE_SQLITE3 - if (m_server_stats_table.empty()) - return; - std::string query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now'), " - "ping = %d, packet_loss = %d " - "WHERE host_id = %u;", m_server_stats_table.c_str(), - peer->getAveragePing(), peer->getPacketLoss(), - peer->getHostId()); - easySQLQuery(query); -#endif -} // writeDisconnectInfoTable - //----------------------------------------------------------------------------- void ServerLobby::updateAddons() { @@ -964,346 +625,104 @@ bool ServerLobby::notifyEventAsynchronous(Event* event) */ void ServerLobby::pollDatabase() { - if (!ServerConfig::m_sql_management || !m_db) + if (!ServerConfig::m_sql_management || !m_db_connector->hasDatabase()) return; - if (StkTime::getMonoTimeMs() < m_last_poll_db_time + 60000) + if (!m_db_connector->isTimeToPoll()) return; - m_last_poll_db_time = StkTime::getMonoTimeMs(); + m_db_connector->updatePollTime(); - if (m_ip_ban_table_exists) + std::vector ip_ban_list = + m_db_connector->getIpBanTableData(); + std::vector ipv6_ban_list = + m_db_connector->getIpv6BanTableData(); + std::vector online_id_ban_list = + m_db_connector->getOnlineIdBanTableData(); + + for (std::shared_ptr& p : STKHost::get()->getPeers()) { - std::string query = - "SELECT ip_start, ip_end, reason, description FROM "; - query += ServerConfig::m_ip_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + if (p->isAIPeer()) + continue; + bool is_kicked = false; + std::string address = ""; + std::string reason = ""; + std::string description = ""; + + if (p->getAddress().isIPv6()) + { + address = p->getAddress().toString(false); + if (address.empty()) + continue; + for (auto& item: ipv6_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (insideIPv6CIDR(item.ipv6_cidr.c_str(), address.c_str()) == 1) { - // IPv4 ban list atm - if (p->isAIPeer() || p->getAddress().isIPv6()) - continue; - - uint32_t ip_start = 0; - uint32_t ip_end = 0; - if (!StringUtils::fromString(data[0], ip_start)) - continue; - if (!StringUtils::fromString(data[1], ip_end)) - continue; - uint32_t peer_addr = p->getAddress().getIP(); - if (ip_start <= peer_addr && ip_end >= peer_addr) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - p->getAddress().toString().c_str(), - data[2], data[3]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } - - if (m_ipv6_ban_table_exists) - { - std::string query = - "SELECT ipv6_cidr, reason, description FROM "; - query += ServerConfig::m_ipv6_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + } + } + else + { + uint32_t peer_addr = p->getAddress().getIP(); + address = p->getAddress().toString(); + for (auto& item: ip_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (item.ip_start <= peer_addr && item.ip_end >= peer_addr) { - std::string ipv6; - if (p->getAddress().isIPv6()) - ipv6 = p->getAddress().toString(false); - // IPv6 ban list atm - if (p->isAIPeer() || ipv6.empty()) - continue; - - char* ipv6_cidr = data[0]; - if (insideIPv6CIDR(ipv6_cidr, ipv6.c_str()) == 1) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - ipv6.c_str(), data[1], data[2]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } - - if (m_online_id_ban_table_exists) - { - std::string query = "SELECT online_id, reason, description FROM "; - query += ServerConfig::m_online_id_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + } + } + if (!is_kicked && !p->getPlayerProfiles().empty()) + { + uint32_t online_id = p->getPlayerProfiles()[0]->getOnlineId(); + for (auto& item: online_id_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (item.online_id == online_id) { - if (p->isAIPeer() - || p->getPlayerProfiles().empty()) - continue; - - uint32_t online_id = 0; - if (!StringUtils::fromString(data[0], online_id)) - continue; - if (online_id == p->getPlayerProfiles()[0]->getOnlineId()) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - p->getAddress().toString().c_str(), - data[1], data[2]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } + } + } + if (is_kicked) + { + Log::info("ServerLobby", "Kick %s, reason: %s, description: %s", + address.c_str(), reason.c_str(), description.c_str()); + p->kick(); + } + } // for p in peers - if (m_player_reports_table_exists && - ServerConfig::m_player_reports_expired_days != 0.0f) - { - std::string query = StringUtils::insertValues( - "DELETE FROM %s " - "WHERE datetime" - "(reported_time, '+%f days') < datetime('now');", - ServerConfig::m_player_reports_table.c_str(), - ServerConfig::m_player_reports_expired_days); - easySQLQuery(query); - } - if (m_server_stats_table.empty()) - return; + m_db_connector->clearOldReports(); - std::string query; auto peers = STKHost::get()->getPeers(); - std::vector exist_hosts; + std::vector hosts; if (!peers.empty()) { for (auto& peer : peers) { if (!peer->isValidated()) continue; - exist_hosts.push_back(peer->getHostId()); + hosts.push_back(peer->getHostId()); } } - if (peers.empty() || exist_hosts.empty()) - { - query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now') " - "WHERE connected_time = disconnected_time;", - m_server_stats_table.c_str()); - } - else - { - std::ostringstream oss; - oss << "UPDATE " << m_server_stats_table - << " SET disconnected_time = datetime('now')" - << " WHERE connected_time = disconnected_time AND" - << " host_id NOT IN ("; - for (unsigned i = 0; i < exist_hosts.size(); i++) - { - oss << exist_hosts[i]; - if (i != (exist_hosts.size() - 1)) - oss << ","; - } - oss << ");"; - query = oss.str(); - } - easySQLQuery(query); + m_db_connector->setDisconnectionTimes(hosts); } // pollDatabase - -//----------------------------------------------------------------------------- -/** Run simple query with write lock waiting and optional function, this - * function has no callback for the return (if any) by the query. - * Return true if no error occurs - */ -bool ServerLobby::easySQLQuery(const std::string& query, - std::function bind_function) const -{ - if (!m_db) - return false; - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - if (bind_function) - bind_function(stmt); - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for easy query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return false; - } - } - else - { - Log::error("ServerLobby", - "Error preparing database for easy query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return false; - } - return true; -} // easySQLQuery - -//----------------------------------------------------------------------------- -/* Write true to result if table name exists in database. */ -void ServerLobby::checkTableExists(const std::string& table, bool& result) -{ - if (!m_db) - return; - sqlite3_stmt* stmt = NULL; - if (!table.empty()) - { - std::string query = StringUtils::insertValues( - "SELECT count(type) FROM sqlite_master " - "WHERE type='table' AND name='%s';", table.c_str()); - - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - int number = sqlite3_column_int(stmt, 0); - if (number == 1) - { - Log::info("ServerLobby", "Table named %s will used.", - table.c_str()); - result = true; - } - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - } - if (!result && !table.empty()) - { - Log::warn("ServerLobby", "Table named %s not found in database.", - table.c_str()); - } -} // checkTableExists - -//----------------------------------------------------------------------------- -std::string ServerLobby::ip2Country(const SocketAddress& addr) const -{ - if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) - return ""; - - std::string cc_code; - std::string query = StringUtils::insertValues( - "SELECT country_code FROM %s " - "WHERE `ip_start` <= %d AND `ip_end` >= %d " - "ORDER BY `ip_start` DESC LIMIT 1;", - ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), - addr.getIP()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; - } - return cc_code; -} // ip2Country - -//----------------------------------------------------------------------------- -std::string ServerLobby::ipv62Country(const SocketAddress& addr) const -{ - if (!m_db || !m_ipv6_geolocation_table_exists) - return ""; - - std::string cc_code; - const std::string& ipv6 = addr.toString(false/*show_port*/); - std::string query = StringUtils::insertValues( - "SELECT country_code FROM %s " - "WHERE `ip_start` <= upperIPv6(\"%s\") AND `ip_end` >= upperIPv6(\"%s\") " - "ORDER BY `ip_start` DESC LIMIT 1;", - ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), - ipv6.c_str()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; - } - return cc_code; -} // ipv62Country - #endif - //----------------------------------------------------------------------------- void ServerLobby::writePlayerReport(Event* event) { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_player_reports_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasPlayerReportsTable()) return; STKPeer* reporter = event->getPeer(); if (!reporter->hasPlayerProfiles()) @@ -1321,65 +740,8 @@ void ServerLobby::writePlayerReport(Event* event) return; auto reporting_npp = reporting_peer->getPlayerProfiles()[0]; - std::string query; - if (ServerConfig::m_ipv6_connection) - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " - "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " - "VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);", - ServerConfig::m_player_reports_table.c_str(), - !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, - reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", - reporter_npp->getOnlineId(), - !reporting_peer->getAddress().isIPv6() ? reporting_peer->getAddress().getIP() : 0, - reporting_peer->getAddress().isIPv6() ? reporting_peer->getAddress().toString(false) : "", - reporting_npp->getOnlineId()); - } - else - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(server_uid, reporter_ip, reporter_online_id, reporter_username, " - "info, reporting_ip, reporting_online_id, reporting_username) " - "VALUES (?, %u, %u, ?, ?, %u, %u, ?);", - ServerConfig::m_player_reports_table.c_str(), - reporter->getAddress().getIP(), reporter_npp->getOnlineId(), - reporting_peer->getAddress().getIP(), reporting_npp->getOnlineId()); - } - bool written = easySQLQuery(query, - [reporter_npp, reporting_npp, info](sqlite3_stmt* stmt) - { - // SQLITE_TRANSIENT to copy string - if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ServerConfig::m_server_uid.c_str()); - } - if (sqlite3_bind_text(stmt, 2, - StringUtils::wideToUtf8(reporter_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporter_npp->getName()).c_str()); - } - if (sqlite3_bind_text(stmt, 3, - StringUtils::wideToUtf8(info).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(info).c_str()); - } - if (sqlite3_bind_text(stmt, 4, - StringUtils::wideToUtf8(reporting_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporting_npp->getName()).c_str()); - } - }); + bool written = m_db_connector->writeReport(reporter, reporter_npp, + reporting_peer.get(), reporting_npp, info); if (written) { NetworkString* success = getNetworkString(); @@ -3359,7 +2721,9 @@ void ServerLobby::clientDisconnected(Event* event) updatePlayerList(); delete msg; - writeDisconnectInfoTable(event->getPeer()); +#ifdef ENABLE_SQLITE3 + m_db_connector->writeDisconnectInfoTable(event->getPeer()); +#endif } // clientDisconnected //----------------------------------------------------------------------------- @@ -3402,14 +2766,7 @@ void ServerLobby::kickPlayerWithReason(STKPeer* peer, const char* reason) const void ServerLobby::saveIPBanTable(const SocketAddress& addr) { #ifdef ENABLE_SQLITE3 - if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) - return; - - std::string query = StringUtils::insertValues( - "INSERT INTO %s (ip_start, ip_end) " - "VALUES (%u, %u);", - ServerConfig::m_ip_ban_table.c_str(), addr.getIP(), addr.getIP()); - easySQLQuery(query); + m_db_connector->saveAddressToIpBanTable(addr); #endif } // saveIPBanTable @@ -3754,9 +3111,9 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, #ifdef ENABLE_SQLITE3 if (country_code.empty() && !peer->getAddress().isIPv6()) - country_code = ip2Country(peer->getAddress()); + country_code = m_db_connector->ip2Country(peer->getAddress()); if (country_code.empty() && peer->getAddress().isIPv6()) - country_code = ipv62Country(peer->getAddress()); + country_code = m_db_connector->ipv62Country(peer->getAddress()); #endif auto red_blue = STKHost::get()->getAllPlayersTeamInfo(); @@ -3828,7 +3185,7 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, message_ack->addFloat(auto_start_timer) .addUInt32(ServerConfig::m_state_frequency) .addUInt8(ServerConfig::m_chat ? 1 : 0) - .addUInt8(m_player_reports_table_exists ? 1 : 0); + .addUInt8(playerReportsTableExists() ? 1 : 0); peer->setSpectator(false); @@ -3889,74 +3246,7 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, } #ifdef ENABLE_SQLITE3 - if (m_server_stats_table.empty() || peer->isAIPeer()) - return; - std::string query; - if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(host_id, ip, ipv6 ,port, online_id, username, player_num, " - "country_code, version, os, ping) " - "VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().toString(false), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); - } - else - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(host_id, ip, port, online_id, username, player_num, " - "country_code, version, os, ping) " - "VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().getIP(), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); - } - easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str()); - } - if (country_code.empty()) - { - if (sqlite3_bind_null(stmt, 2) != SQLITE_OK) - { - Log::error("easySQLQuery", - "Failed to bind NULL for country code."); - } - } - else - { - if (sqlite3_bind_text(stmt, 2, country_code.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind country: %s.", - country_code.c_str()); - } - } - auto version_os = - StringUtils::extractVersionOS(peer->getUserVersion()); - if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.first.c_str()); - } - if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.second.c_str()); - } - } - ); + m_db_connector->onPlayerJoinQueries(peer, online_id, player_count, country_code); #endif } // handleUnencryptedConnection @@ -4796,66 +4086,34 @@ void ServerLobby::resetServer() void ServerLobby::testBannedForIP(STKPeer* peer) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_ip_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasIpBanTable()) return; // Test for IPv4 if (peer->getAddress().isIPv6()) return; - int row_id = -1; - unsigned ip_start = 0; - unsigned ip_end = 0; - std::string query = StringUtils::insertValues( - "SELECT rowid, ip_start, ip_end, reason, description FROM %s " - "WHERE ip_start <= %u AND ip_end >= %u " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_ip_ban_table.c_str(), - peer->getAddress().getIP(), peer->getAddress().getIP()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - ip_start = (unsigned)sqlite3_column_int64(stmt, 1); - ip_end = (unsigned)sqlite3_column_int64(stmt, 2); - const char* reason = (char*)sqlite3_column_text(stmt, 3); - const char* desc = (char*)sqlite3_column_text(stmt, 4); - Log::info("ServerLobby", "%s banned by IP: %s " + bool is_banned = false; + uint32_t ip_start = 0; + uint32_t ip_end = 0; + + std::vector ip_ban_list = + m_db_connector->getIpBanTableData(peer->getAddress().getIP()); + if (!ip_ban_list.empty()) + { + is_banned = true; + ip_start = ip_ban_list[0].ip_start; + ip_end = ip_ban_list[0].ip_end; + int row_id = ip_ban_list[0].row_id; + std::string reason = ip_ban_list[0].reason; + std::string description = ip_ban_list[0].description; + Log::info("ServerLobby", "%s banned by IP: %s " "(rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) - { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE ip_start = %u AND ip_end = %u;", - ServerConfig::m_ip_ban_table.c_str(), ip_start, ip_end); - easySQLQuery(query); + peer->getAddress().toString().c_str(), reason.c_str(), row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseIpBanTriggerCount(ip_start, ip_end); #endif } // testBannedForIP @@ -4863,78 +4121,33 @@ void ServerLobby::testBannedForIP(STKPeer* peer) const void ServerLobby::testBannedForIPv6(STKPeer* peer) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_ipv6_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasIpv6BanTable()) return; // Test for IPv6 if (!peer->getAddress().isIPv6()) return; - int row_id = -1; - std::string ipv6_cidr; - std::string query = StringUtils::insertValues( - "SELECT rowid, ipv6_cidr, reason, description FROM %s " - "WHERE insideIPv6CIDR(ipv6_cidr, ?) = 1 " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_ipv6_ban_table.c_str()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - if (sqlite3_bind_text(stmt, 1, - peer->getAddress().toString(false).c_str(), -1, SQLITE_TRANSIENT) - != SQLITE_OK) - { - Log::error("ServerLobby", "Error binding ipv6 addr for query: %s", - sqlite3_errmsg(m_db)); - } + bool is_banned = false; + std::string ipv6_cidr = ""; - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - ipv6_cidr = (char*)sqlite3_column_text(stmt, 1); - const char* reason = (char*)sqlite3_column_text(stmt, 2); - const char* desc = (char*)sqlite3_column_text(stmt, 3); - Log::info("ServerLobby", "%s banned by IP: %s " - "(rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) + std::vector ipv6_ban_list = + m_db_connector->getIpv6BanTableData(peer->getAddress().toString(false)); + + if (!ipv6_ban_list.empty()) { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str()); - easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ipv6_cidr.c_str()); - } - }); + is_banned = true; + ipv6_cidr = ipv6_ban_list[0].ipv6_cidr; + int row_id = ipv6_ban_list[0].row_id; + std::string reason = ipv6_ban_list[0].reason; + std::string description = ipv6_ban_list[0].description; + Log::info("ServerLobby", "%s banned by IPv6: %s " + "(rowid: %d, description: %s).", + peer->getAddress().toString(false).c_str(), reason.c_str(), row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseIpv6BanTriggerCount(ipv6_cidr); #endif } // testBannedForIPv6 @@ -4943,57 +4156,26 @@ void ServerLobby::testBannedForOnlineId(STKPeer* peer, uint32_t online_id) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_online_id_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasOnlineIdBanTable()) return; - int row_id = -1; - std::string query = StringUtils::insertValues( - "SELECT rowid, reason, description FROM %s " - "WHERE online_id = %u " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_online_id_ban_table.c_str(), online_id); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - const char* reason = (char*)sqlite3_column_text(stmt, 1); - const char* desc = (char*)sqlite3_column_text(stmt, 2); - Log::info("ServerLobby", "%s banned by online id: %s " - "(online id: %u rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, online_id, - row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", "Error finalize database: %s", - sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database: %s", - sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) + bool is_banned = false; + std::vector online_id_ban_list = + m_db_connector->getOnlineIdBanTableData(online_id); + + if (!online_id_ban_list.empty()) { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE online_id = %u;", - ServerConfig::m_online_id_ban_table.c_str(), online_id); - easySQLQuery(query); + is_banned = true; + int row_id = online_id_ban_list[0].row_id; + std::string reason = online_id_ban_list[0].reason; + std::string description = online_id_ban_list[0].description; + Log::info("ServerLobby", "%s banned by online id: %s " + "(online id: %u, rowid: %d, description: %s).", + peer->getAddress().toString().c_str(), reason.c_str(), online_id, row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseOnlineIdBanTriggerCount(online_id); #endif } // testBannedForOnlineId @@ -5001,34 +4183,7 @@ void ServerLobby::testBannedForOnlineId(STKPeer* peer, void ServerLobby::listBanTable() { #ifdef ENABLE_SQLITE3 - if (!m_db) - return; - auto printer = [](void* data, int argc, char** argv, char** name) - { - for (int i = 0; i < argc; i++) - { - std::cout << name[i] << " = " << (argv[i] ? argv[i] : "NULL") - << "\n"; - } - std::cout << "\n"; - return 0; - }; - if (m_ip_ban_table_exists) - { - std::string query = "SELECT * FROM "; - query += ServerConfig::m_ip_ban_table; - query += ";"; - std::cout << "IP ban list:\n"; - sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); - } - if (m_online_id_ban_table_exists) - { - std::string query = "SELECT * FROM "; - query += ServerConfig::m_online_id_ban_table; - query += ";"; - std::cout << "Online Id ban list:\n"; - sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); - } + m_db_connector->listBanTable(); #endif } // listBanTable @@ -6052,3 +5207,13 @@ void ServerLobby::handleServerCommand(Event* event, delete chat; } } // handleServerCommand + +//----------------------------------------------------------------------------- +bool ServerLobby::playerReportsTableExists() const +{ +#ifdef ENABLE_SQLITE3 + return m_db_connector->hasPlayerReportsTable(); +#else + return false; +#endif +} diff --git a/src/network/protocols/server_lobby.hpp b/src/network/protocols/server_lobby.hpp index a54f9f901bb..a8e9245add7 100644 --- a/src/network/protocols/server_lobby.hpp +++ b/src/network/protocols/server_lobby.hpp @@ -34,11 +34,12 @@ #include #include -#ifdef ENABLE_SQLITE3 -#include -#endif +// #ifdef ENABLE_SQLITE3 +// #include +// #endif class BareNetworkString; +class DatabaseConnector; class NetworkItemManager; class NetworkString; class NetworkPlayerProfile; @@ -78,39 +79,12 @@ class ServerLobby : public LobbyProtocol std::string m_country_code; bool m_tried = false; }; - bool m_player_reports_table_exists; #ifdef ENABLE_SQLITE3 - sqlite3* m_db; - - std::string m_server_stats_table; - - bool m_ip_ban_table_exists; - - bool m_ipv6_ban_table_exists; - - bool m_online_id_ban_table_exists; - - bool m_ip_geolocation_table_exists; - - bool m_ipv6_geolocation_table_exists; - - uint64_t m_last_poll_db_time; + DatabaseConnector* m_db_connector; void pollDatabase(); - - bool easySQLQuery(const std::string& query, - std::function bind_function = nullptr) const; - - void checkTableExists(const std::string& table, bool& result); - - std::string ip2Country(const SocketAddress& addr) const; - - std::string ipv62Country(const SocketAddress& addr) const; #endif - void initDatabase(); - - void destroyDatabase(); std::atomic m_state; @@ -378,7 +352,6 @@ class ServerLobby : public LobbyProtocol void testBannedForIP(STKPeer* peer) const; void testBannedForIPv6(STKPeer* peer) const; void testBannedForOnlineId(STKPeer* peer, uint32_t online_id) const; - void writeDisconnectInfoTable(STKPeer* peer); void writePlayerReport(Event* event); bool supportsAI(); void updateAddons(); @@ -419,6 +392,7 @@ class ServerLobby : public LobbyProtocol uint32_t getServerIdOnline() const { return m_server_id_online; } void setClientServerHostId(uint32_t id) { m_client_server_host_id = id; } static int m_fixed_laps; + bool playerReportsTableExists() const; }; // class ServerLobby #endif // SERVER_LOBBY_HPP From 9fb7448abe0687bfc5b7f5120362521d915bfa22 Mon Sep 17 00:00:00 2001 From: kimden <23140380+kimden@users.noreply.github.com> Date: Wed, 3 Jul 2024 23:13:40 +0400 Subject: [PATCH 2/2] Get rid of repeated code, add comments easySQLQuery was generalized to return output rows if requested, which allowed to shorten the code for many queries. The code for binding values to sqlite statement was also repeated many times. Two auxiliary structures were introduced, so that it's possible to provide at the same time both those parameters which require and those which don't require binding, in a single StringUtils::insertValues() call. --- src/network/database_connector.cpp | 582 +++++++++++++------------ src/network/database_connector.hpp | 100 ++++- src/network/protocols/server_lobby.hpp | 4 - 3 files changed, 391 insertions(+), 295 deletions(-) diff --git a/src/network/database_connector.cpp b/src/network/database_connector.cpp index 93844c897c7..30155495b05 100644 --- a/src/network/database_connector.cpp +++ b/src/network/database_connector.cpp @@ -1,6 +1,6 @@ // // SuperTuxKart - a fun racing game with go-kart -// Copyright (C) 2013-2015 SuperTuxKart-Team +// Copyright (C) 2024 SuperTuxKart-Team // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License @@ -29,6 +29,61 @@ #include "utils/log.hpp" //----------------------------------------------------------------------------- +/** Prints "?" to the output stream and saves the Binder object to the + * corresponding BinderCollection so that it can produce bind function later + * When we invoke StringUtils::insertValues with a Binder argument, the + * implementation of insertValues ensures that this function is invoked for + * all Binder arguments from left to right. + */ +std::ostream& operator << (std::ostream& os, const Binder& binder) +{ + os << "?"; + binder.m_collection.lock()->m_binders.emplace_back(std::make_shared(binder)); + return os; +} // operator << (Binder) + +//----------------------------------------------------------------------------- +/** Returns a bind function that should be used inside an easySQLQuery. As the + * Binder objects are already ordered in a correct way, the indices just go + * from 1 upwards. Depending on a particular Binder, we can also bind NULL + * instead of a string. + */ +std::function BinderCollection::getBindFunction() const +{ + auto binders = m_binders; + return [binders](sqlite3_stmt* stmt) + { + int idx = 1; + for (std::shared_ptr binder: binders) + { + if (binder) + { + // SQLITE_TRANSIENT to copy string + if (binder->m_use_null_if_empty && binder->m_value.empty()) + { + if (sqlite3_bind_null(stmt, idx) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind NULL for %s.", + binder->m_name.c_str()); + } + } + else + { + if (sqlite3_bind_text(stmt, idx, binder->m_value.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s as %s.", + binder->m_value.c_str(), binder->m_name.c_str()); + } + } + } + ++idx; + } + }; +} // BinderCollection::getBindFunction + +//----------------------------------------------------------------------------- +/** Opens the database, sets its busy handler and variables related to it. */ void DatabaseConnector::initDatabase() { m_last_poll_db_time = StkTime::getMonoTimeMs(); @@ -83,6 +138,7 @@ void DatabaseConnector::initDatabase() } // initDatabase //----------------------------------------------------------------------------- +/** Closes the database. */ void DatabaseConnector::destroyDatabase() { auto peers = STKHost::get()->getPeers(); @@ -93,12 +149,18 @@ void DatabaseConnector::destroyDatabase() } // destroyDatabase //----------------------------------------------------------------------------- -/** Run simple query with write lock waiting and optional function, this - * function has no callback for the return (if any) by the query. - * Return true if no error occurs +/** Runs simple query with optional bind function. If output vector pointer is + * not (default) nullptr, then the output is written there. + * \param query The SQL query with '?'-placeholders for values to bind. + * \param output The 2D vector for output rows. If nullptr, the query output + * is ignored. + * \param bind_function The function for binding missing values. + * \return True if no error occurs. */ -bool DatabaseConnector::easySQLQuery(const std::string& query, - std::function bind_function) const +bool DatabaseConnector::easySQLQuery( + const std::string& query, std::vector>* output, + std::function bind_function, + std::string null_value) const { if (!m_db) return false; @@ -109,6 +171,24 @@ bool DatabaseConnector::easySQLQuery(const std::string& query, if (bind_function) bind_function(stmt); ret = sqlite3_step(stmt); + if (output) + { + output->clear(); + while (ret == SQLITE_ROW) + { + output->emplace_back(); + int columns = sqlite3_column_count(stmt); + for (int i = 0; i < columns; ++i) + { + const char* value = (char*)sqlite3_column_text(stmt, i); + if (value == nullptr) + output->back().push_back(null_value); + else + output->back().push_back(std::string(value)); + } + ret = sqlite3_step(stmt); + } + } ret = sqlite3_finalize(stmt); if (ret != SQLITE_OK) { @@ -129,38 +209,30 @@ bool DatabaseConnector::easySQLQuery(const std::string& query, } // easySQLQuery //----------------------------------------------------------------------------- -/* Write true to result if table name exists in database. */ +/** Performs a query to determine if a certain table exists. + * \param table The searched name. + * \param result The output value. + */ void DatabaseConnector::checkTableExists(const std::string& table, bool& result) { if (!m_db) return; - sqlite3_stmt* stmt = NULL; + result = false; if (!table.empty()) { std::string query = StringUtils::insertValues( "SELECT count(type) FROM sqlite_master " "WHERE type='table' AND name='%s';", table.c_str()); - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - int number = sqlite3_column_int(stmt, 0); - if (number == 1) - { - Log::info("DatabaseConnector", "Table named %s will be used.", - table.c_str()); - result = true; - } - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) + int number; + if (StringUtils::fromString(output[0][0], number) && number == 1) { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); + Log::info("DatabaseConnector", "Table named %s will be used.", + table.c_str()); + result = true; } } } @@ -172,6 +244,12 @@ void DatabaseConnector::checkTableExists(const std::string& table, bool& result) } // checkTableExists //----------------------------------------------------------------------------- +/** Queries the database's IP mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const { if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) @@ -185,34 +263,21 @@ std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), addr.getIP()); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; + cc_code = output[0][0]; } return cc_code; } // ip2Country //----------------------------------------------------------------------------- +/** Queries the database's IPv6 mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const { if (!m_db || !m_ipv6_geolocation_table_exists) @@ -227,34 +292,16 @@ std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), ipv6.c_str()); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; + cc_code = output[0][0]; } return cc_code; } // ipv62Country // ---------------------------------------------------------------------------- +/** A function invoked within SQLite */ void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, sqlite3_value** argv) { @@ -274,6 +321,9 @@ void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, } // ---------------------------------------------------------------------------- +/** A function that checks within SQLite whether an IPv6 address (argv[1]) + * is located within a specified block (argv[0]) of IPv6 addresses. + */ void DatabaseConnector::insideIPv6CIDRSQL(sqlite3_context* context, int argc, sqlite3_value** argv) { @@ -314,6 +364,10 @@ sqlite3_extension_init(sqlite3* db, char** pzErrMsg, */ //----------------------------------------------------------------------------- +/** When a peer disconnects from the server, this function saves to the + * database peer's disconnection time and statistics (ping and packet loss). + * \param peer Disconnecting peer. + */ void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) { if (m_server_stats_table.empty()) @@ -328,7 +382,11 @@ void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) } // writeDisconnectInfoTable //----------------------------------------------------------------------------- - +/** Creates necessary tables and views if they don't exist yet in the database. + * As the function is invoked during the server launch, it also updates rows + * related to players whose disconnection time wasn't written, and loads + * last used host id. + */ void DatabaseConnector::initServerStatsTable() { if (!ServerConfig::m_sql_management || !m_db) @@ -356,26 +414,10 @@ void DatabaseConnector::initServerStatsTable() " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" ") WITHOUT ROWID;"; std::string query = oss.str(); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret == SQLITE_OK) - m_server_stats_table = table_name; - else - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } + + if (easySQLQuery(query)) + m_server_stats_table = table_name; + if (m_server_stats_table.empty()) return; @@ -501,31 +543,22 @@ void DatabaseConnector::initServerStatsTable() uint32_t last_host_id = 0; query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", m_server_stats_table.c_str()); - ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) + + std::vector> output; + if (easySQLQuery(query, &output)) { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL) + if (!output.empty() && !output[0].empty() + && StringUtils::fromString(output[0][0], last_host_id)) { - last_host_id = (unsigned)sqlite3_column_int64(stmt, 0); Log::info("DatabaseConnector", "%u was last server session max host id.", last_host_id); } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } } else { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); m_server_stats_table = ""; } + STKHost::get()->setNextHostId(last_host_id); // Update disconnected time (if stk crashed it will not be written) @@ -537,25 +570,41 @@ void DatabaseConnector::initServerStatsTable() } // initServerStatsTable //----------------------------------------------------------------------------- -bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, - STKPeer* reporting, std::shared_ptr reporting_npp, - irr::core::stringw& info) +/** Writes a report of one player about another player. + * \param reporter Peer that sends the report. + * \param reporter_npp Player profile that sends the report. + * \param reporting Peer that is reported. + * \param reporting_npp Player profile that is reported. + * \param info The report message. + * \return True if the database query succeeded. + */ +bool DatabaseConnector::writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info) { std::string query; + + std::shared_ptr coll = std::make_shared(); if (ServerConfig::m_ipv6_connection) { query = StringUtils::insertValues( "INSERT INTO %s " "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " - "VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);", + "VALUES (%s, %u, \"%s\", %u, %s, %s, %u, \"%s\", %u, %s);", ServerConfig::m_player_reports_table.c_str(), + Binder(coll, ServerConfig::m_server_uid, "server_uid"), !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), !reporting->getAddress().isIPv6() ? reporting->getAddress().getIP() : 0, reporting->getAddress().isIPv6() ? reporting->getAddress().toString(false) : "", - reporting_npp->getOnlineId()); + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); } else { @@ -563,46 +612,29 @@ bool DatabaseConnector::writeReport(STKPeer* reporter, std::shared_ptrgetAddress().getIP(), reporter_npp->getOnlineId(), - reporting->getAddress().getIP(), reporting_npp->getOnlineId()); - } - return easySQLQuery(query, - [reporter_npp, reporting_npp, info](sqlite3_stmt* stmt) - { - // SQLITE_TRANSIENT to copy string - if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ServerConfig::m_server_uid.c_str()); - } - if (sqlite3_bind_text(stmt, 2, - StringUtils::wideToUtf8(reporter_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporter_npp->getName()).c_str()); - } - if (sqlite3_bind_text(stmt, 3, - StringUtils::wideToUtf8(info).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(info).c_str()); - } - if (sqlite3_bind_text(stmt, 4, - StringUtils::wideToUtf8(reporting_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporting_npp->getName()).c_str()); - } - }); + Binder(coll, ServerConfig::m_server_uid, "server_uid"), + reporter->getAddress().getIP(), + reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), + reporting->getAddress().getIP(), + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); + } + return easySQLQuery(query, nullptr, coll->getBindFunction()); } // writeReport //----------------------------------------------------------------------------- +/** Gets the rows from IPv4 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IP address to check the database for. If zero, all rows + * will be given. + * \return A vector of rows in the form of IpBanTableData structures. + */ std::vector DatabaseConnector::getIpBanTableData(uint32_t ip) const { @@ -624,26 +656,32 @@ DatabaseConnector::getIpBanTableData(uint32_t ip) const oss << " LIMIT 1"; oss << ";"; std::string query = oss.str(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) - { - std::vector* vec = (std::vector*)ptr; - IpBanTableData element; - if (!StringUtils::fromString(data[0], element.row_id)) - return 0; - if (!StringUtils::fromString(data[1], element.ip_start)) - return 0; - if (!StringUtils::fromString(data[2], element.ip_end)) - return 0; - element.reason = std::string(data[3]); - element.description = std::string(data[4]); - vec->push_back(element); - return 0; - }, &result, NULL); + + std::vector> output; + easySQLQuery(query, &output); + + for (std::vector& row: output) + { + IpBanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + if (!StringUtils::fromString(row[1], element.ip_start)) + continue; + if (!StringUtils::fromString(row[2], element.ip_end)) + continue; + element.reason = row[3]; + element.description = row[4]; + result.push_back(element); + } return result; } // getIpBanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv4, this function increases + * the trigger count. + * \param ip_start Start of IP ban range corresponding to peer. + * \param ip_end End of IP ban range corresponding to peer. + */ void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const { std::string query = StringUtils::insertValues( @@ -655,6 +693,13 @@ void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip } // getIpBanTableData //----------------------------------------------------------------------------- +/** Gets the rows from IPv6 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IPv6 address to check the database for. If empty, all rows + * will be given. + * \return A vector of rows in the form of Ipv6BanTableData structures. + */ std::vector DatabaseConnector::getIpv6BanTableData(std::string ipv6) const { @@ -664,88 +709,68 @@ DatabaseConnector::getIpv6BanTableData(std::string ipv6) const return result; } bool single_ip = !ipv6.empty(); - std::ostringstream oss; - oss << "SELECT rowid, ipv6_cidr, reason, description FROM "; - oss << (std::string)ServerConfig::m_ipv6_ban_table; - oss << " WHERE "; + std::string query; + std::shared_ptr coll = std::make_shared(); + + query = StringUtils::insertValues( + "SELECT rowid, ipv6_cidr, reason, description FROM %s WHERE ", + ServerConfig::m_ipv6_ban_table.c_str() + ); if (single_ip) - oss << "insideIPv6CIDR(ipv6_cidr, ?) = 1 AND "; - oss << "datetime('now') > datetime(starting_time) AND " + query += StringUtils::insertValues( + "insideIPv6CIDR(ipv6_cidr, %s) = 1 AND ", + Binder(coll, ipv6, "ipv6") + ); + + query += "datetime('now') > datetime(starting_time) AND " "(expired_days is NULL OR datetime" "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_ip) - oss << " LIMIT 1"; - oss << ";"; - std::string query = oss.str(); + query += " LIMIT 1;"; - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - if (single_ip) - { - if (sqlite3_bind_text(stmt, 1, - ipv6.c_str(), -1, SQLITE_TRANSIENT) - != SQLITE_OK) - { - Log::error("DatabaseConnector", "Error binding ipv6 addr for query: %s", - sqlite3_errmsg(m_db)); - return result; - } - } - ret = sqlite3_step(stmt); - while (ret == SQLITE_ROW) - { - const char* rowid_cstr = (char*)sqlite3_column_text(stmt, 0); - const char* ipv6cidr_cstr = (char*)sqlite3_column_text(stmt, 1); - const char* reason_cstr = (char*)sqlite3_column_text(stmt, 2); - const char* description_cstr = (char*)sqlite3_column_text(stmt, 3); - Ipv6BanTableData element; - if (StringUtils::fromString(rowid_cstr, element.row_id)) - { - element.ipv6_cidr = std::string(ipv6cidr_cstr); - element.reason = std::string(reason_cstr); - element.description = std::string(description_cstr); - result.push_back(element); - } - ret = sqlite3_step(stmt); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("DatabaseConnector", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else + std::vector> output; + easySQLQuery(query, &output, coll->getBindFunction()); + + for (std::vector& row: output) { - Log::error("DatabaseConnector", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return result; + Ipv6BanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + element.ipv6_cidr = row[1]; + element.reason = row[2]; + element.description = row[3]; + result.push_back(element); } return result; } // getIpv6BanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv6, this function increases + * the trigger count. + * \param ipv6_cidr Block of IPv6 addresses corresponding to the peer. + */ void DatabaseConnector::increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const { + std::shared_ptr coll = std::make_shared(); std::string query = StringUtils::insertValues( "UPDATE %s SET trigger_count = trigger_count + 1, " "last_trigger = datetime('now') " - "WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str()); - easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ipv6_cidr.c_str()); - } - }); + "WHERE ipv6_cidr = %s;", + ServerConfig::m_ipv6_ban_table.c_str(), + Binder(coll, ipv6_cidr, "ipv6_cidr") + ); + easySQLQuery(query, nullptr, coll->getBindFunction()); } // increaseIpv6BanTriggerCount //----------------------------------------------------------------------------- +/** Gets the rows from online id ban table, either all of them (for polling + * purposes), or those describing a certain online id (if only one peer has + * to be checked). + * \param online_id The online id to check the database for. If empty, all + * rows will be given. + * \return A vector of rows in the form of OnlineIdBanTableData structures. + */ std::vector DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const { @@ -786,6 +811,10 @@ DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const } // getOnlineIdBanTableData //----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by online id, this function + * increases the trigger count. + * \param online_id Online id of the peer. + */ void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) const { std::string query = StringUtils::insertValues( @@ -797,6 +826,9 @@ void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) cons } // increaseOnlineIdBanTriggerCount //----------------------------------------------------------------------------- +/** Clears reports that are older than a certain number of days + * (specified in the server config). + */ void DatabaseConnector::clearOldReports() { if (m_player_reports_table_exists && @@ -813,6 +845,10 @@ void DatabaseConnector::clearOldReports() } // clearOldReports //----------------------------------------------------------------------------- +/** Sets disconnection times for those peers that already left the server, but + * whose disconnection times wasn't set yet. + * \param present_hosts List of online ids of present peers. + */ void DatabaseConnector::setDisconnectionTimes(std::vector& present_hosts) { if (!hasServerStatsTable()) @@ -841,6 +877,10 @@ void DatabaseConnector::setDisconnectionTimes(std::vector& present_hos } // setDisconnectionTimes //----------------------------------------------------------------------------- +/** Adds a specified IP address to the IPv4 ban table. Usually invoked from + * network console. + * \param addr Address to ban. + */ void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) { if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) @@ -854,22 +894,39 @@ void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) } // saveAddressToIpBanTable //----------------------------------------------------------------------------- +/** Called when the player joins the server, inserts player info into database. + * \param peer The peer that joins. + * \param online_id Player's online id. + * \param player_count Number of players joining using a single peer. + * \param country_code Country code deduced by global or local IP mapping. + */ void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, uint32_t online_id, unsigned player_count, const std::string& country_code) { if (m_server_stats_table.empty() || peer->isAIPeer()) return; std::string query; + std::shared_ptr coll = std::make_shared(); + auto version_os = StringUtils::extractVersionOS(peer->getUserVersion()); if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) { query = StringUtils::insertValues( "INSERT INTO %s " - "(host_id, ip, ipv6 ,port, online_id, username, player_num, " + "(host_id, ip, ipv6, port, online_id, username, player_num, " "country_code, version, os, ping) " - "VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().toString(false), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); + "VALUES (%u, 0, \"%s\", %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().toString(false), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); } else { @@ -877,56 +934,25 @@ void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, "INSERT INTO %s " "(host_id, ip, port, online_id, username, player_num, " "country_code, version, os, ping) " - "VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().getIP(), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); - } - easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str()); - } - if (country_code.empty()) - { - if (sqlite3_bind_null(stmt, 2) != SQLITE_OK) - { - Log::error("easySQLQuery", - "Failed to bind NULL for country code."); - } - } - else - { - if (sqlite3_bind_text(stmt, 2, country_code.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind country: %s.", - country_code.c_str()); - } - } - auto version_os = - StringUtils::extractVersionOS(peer->getUserVersion()); - if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.first.c_str()); - } - if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.second.c_str()); - } - }); + "VALUES (%u, %u, %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().getIP(), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); + } + easySQLQuery(query, nullptr, coll->getBindFunction()); } // onPlayerJoinQueries //----------------------------------------------------------------------------- +/** Prints all rows of the IPv4 ban table. Called from the network console. */ void DatabaseConnector::listBanTable() { if (!m_db) diff --git a/src/network/database_connector.hpp b/src/network/database_connector.hpp index 282fc859863..d22e588b8bd 100644 --- a/src/network/database_connector.hpp +++ b/src/network/database_connector.hpp @@ -24,18 +24,86 @@ #include "utils/string_utils.hpp" #include "utils/time.hpp" -#include -#include #include +#include #include -#include #include +#include +#include class SocketAddress; class STKPeer; class NetworkPlayerProfile; +/** The purpose of Binder and BinderCollection structures is to allow + * putting values to bind inside StringUtils::insertValues, which is commonly + * used for values that don't require binding (such as integers). + * Unlike previously used approach with separate query formation and binding, + * the arguments are written in the code in the order of query appearance + * (even though real binding still happens later). It also avoids repeated + * binding code. + * + * Syntax looks as follows: + * std::shared_ptr coll = std::make_shared...; + * std::string query_string = StringUtils::insertValues( + * "query contents with wildcards of type %d, %s, %u, ..." + * "where %s is put for values that will be bound later", + * values to insert, ..., Binder(coll, other parameters), ...); + * Then the bind function (e.g. for usage in easySQLQuery) should be + * coll->getBindFunction(). + */ + +struct Binder; + +/** BinderCollection is a structure that collects Binder objects used in an + * SQL query formed with insertValues() (see above). For a single query, a + * single instance of BinderCollection should be used. After a query is + * formed, BinderCollection can produce bind function to use with sqlite3. + */ +struct BinderCollection +{ + std::vector> m_binders; + + std::function getBindFunction() const; +}; +/** Binder is a wrapper for a string to be bound into an SQL query. See above + * for its usage in insertValues(). When it's printed to an output stream + * (in particular, this is done in insertValues implementation), this Binder + * is added to the query's BinderCollection, and the '?'-placeholder is added + * to the query string instead of %s. + * + * When using Binder, make sure that: + * - operator << is invoked on it exactly once; + * - operator << is invoked on several Binders in the order in which they go + * in the query; + * - before calling insertValues, there is a %-wildcard corresponding to the + * Binder in the query string (and not '?'). + * For example, when the query formed inside of a function depends on its + * arguments, it should be formed part by part, from left to right. + * Of course, you can choose the "default" way, binding values separately from + * insertValues() call. + */ +struct Binder +{ + std::weak_ptr m_collection; + std::string m_value; + std::string m_name; + bool m_use_null_if_empty; + + Binder(std::shared_ptr collection, std::string value, + std::string name = "", bool use_null_if_empty = false): + m_collection(collection), m_value(value), + m_name(name), m_use_null_if_empty(use_null_if_empty) {} +}; + +std::ostream& operator << (std::ostream& os, const Binder& binder); + +/** A class that manages the database operations needed for the server to work. + * The SQL queries are intended to be placed only within the implementation + * of this class, while the logic corresponding to those queries should not + * belong here. + */ class DatabaseConnector { private: @@ -50,6 +118,7 @@ class DatabaseConnector uint64_t m_last_poll_db_time; public: + /** Corresponds to the row of IPv4 ban table. */ struct IpBanTableData { int row_id; @@ -58,13 +127,17 @@ class DatabaseConnector std::string reason; std::string description; }; - struct Ipv6BanTableData { + /** Corresponds to the row of IPv6 ban table. */ + struct Ipv6BanTableData + { int row_id; std::string ipv6_cidr; std::string reason; std::string description; }; - struct OnlineIdBanTableData { + /** Corresponds to the row of online id ban table. */ + struct OnlineIdBanTableData + { int row_id; uint32_t online_id; std::string reason; @@ -74,7 +147,9 @@ class DatabaseConnector void destroyDatabase(); bool easySQLQuery(const std::string& query, - std::function bind_function = nullptr) const; + std::vector>* output = nullptr, + std::function bind_function = nullptr, + std::string null_value = "") const; void checkTableExists(const std::string& table, bool& result); @@ -83,14 +158,15 @@ class DatabaseConnector std::string ipv62Country(const SocketAddress& addr) const; static void upperIPv6SQL(sqlite3_context* context, int argc, - sqlite3_value** argv); + sqlite3_value** argv); static void insideIPv6CIDRSQL(sqlite3_context* context, int argc, - sqlite3_value** argv); + sqlite3_value** argv); void writeDisconnectInfoTable(STKPeer* peer); void initServerStatsTable(); - bool writeReport(STKPeer* reporter, std::shared_ptr reporter_npp, - STKPeer* reporting, std::shared_ptr reporting_npp, - irr::core::stringw& info); + bool writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info); bool hasDatabase() const { return m_db != nullptr; } bool hasServerStatsTable() const { return !m_server_stats_table.empty(); } bool hasPlayerReportsTable() const @@ -115,7 +191,5 @@ class DatabaseConnector void listBanTable(); }; - - #endif // ifndef DATABASE_CONNECTOR_HPP #endif // ifdef ENABLE_SQLITE3 diff --git a/src/network/protocols/server_lobby.hpp b/src/network/protocols/server_lobby.hpp index a8e9245add7..fc18cfeca43 100644 --- a/src/network/protocols/server_lobby.hpp +++ b/src/network/protocols/server_lobby.hpp @@ -34,10 +34,6 @@ #include #include -// #ifdef ENABLE_SQLITE3 -// #include -// #endif - class BareNetworkString; class DatabaseConnector; class NetworkItemManager;