From babd18ef5ae8aab9f7e95907bac5d91d777a7685 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Thu, 7 Dec 2023 15:03:13 -0700 Subject: [PATCH] add custom median function to sqlite3 copy, sort, and index leaving quickselect implementation commented out because while there's currently no meaningful performance difference from sorting, future improvements might change that --- src/dbutils.c | 144 ++++++++++++++++++++++++++++++- test/unit/googletest/dbutils.cpp | 33 ++++++- 2 files changed, 174 insertions(+), 3 deletions(-) diff --git a/src/dbutils.c b/src/dbutils.c index bedc46a2b..78a9198d5 100644 --- a/src/dbutils.c +++ b/src/dbutils.c @@ -1078,6 +1078,146 @@ static void stdev_final(sqlite3_context *context) { sqlite3_result_double(context, sqrt(variance)); } +static void median_step(sqlite3_context *context, int argc, sqlite3_value **argv) { + (void) argc; + sll_t *data = (sll_t *) sqlite3_aggregate_context(context, sizeof(*data)); + if (sll_get_size(data) == 0) { + sll_init(data); + } + + const double value = sqlite3_value_double(argv[0]); + sll_push(data, (void *) (uintptr_t) value); +} + +/* /\* */ +/* * find kth largest element */ +/* * */ +/* * Adapted from code by Russell Cohen */ +/* * https://rcoh.me/posts/linear-time-median-finding/ */ +/* *\/ */ +/* static double quickselect(sll_t *sll, uint64_t count, uint64_t k) { */ +/* /\* cache unused values here since partitioning destroys the original list *\/ */ +/* sll_t cache; */ +/* sll_init(&cache); */ + +/* sll_t lt, eq, gt; */ +/* sll_init(<); */ +/* sll_init(&eq); */ +/* sll_init(>); */ + +/* while (count > 1) { */ +/* /\* TODO: Better pivot selection *\/ */ +/* const uint64_t pivot_idx = (rand() * rand()) % count; */ +/* double pivot = 0; */ +/* size_t i = 0; */ +/* sll_loop(sll, node) { */ +/* if (i == pivot_idx) { */ +/* pivot = (double) (uintptr_t) sll_node_data(node); */ +/* break; */ +/* } */ +/* i++; */ +/* } */ + +/* sll_node_t *node = NULL; */ +/* while ((node = sll_head_node(sll))) { */ +/* const double value = (double) (uint64_t) sll_node_data(node); */ +/* if (value < pivot) { */ +/* sll_move_append_first(<, sll, 1); */ +/* } */ +/* else if (value > pivot) { */ +/* sll_move_append_first(>, sll, 1); */ +/* } */ +/* else { */ +/* sll_move_append_first(&eq, sll, 1); */ +/* } */ +/* } */ + +/* /\* sll is empty at this point *\/ */ + +/* const uint64_t lt_size = sll_get_size(<); */ +/* const uint64_t eq_size = sll_get_size(&eq); */ + +/* if (k < lt_size) { */ +/* sll_move_append(sll, <); */ +/* sll_move_append(&cache, &eq); */ +/* sll_move_append(&cache, >); */ +/* } */ +/* else if (k < (lt_size + eq_size)) { */ +/* sll_move_append(&cache, <); */ +/* sll_move_append(sll, &eq); */ +/* sll_move_append(&cache, >); */ +/* break; */ +/* } */ +/* else { */ +/* k -= lt_size + eq_size; */ +/* sll_move_append(&cache, <); */ +/* sll_move_append(&cache, &eq); */ +/* sll_move_append(sll, >); */ +/* } */ + +/* count = sll_get_size(sll); */ +/* } */ + +/* /\* restore original list's contents (different order) *\/ */ +/* sll_move_append(sll, &cache); */ + +/* return (double) (uintptr_t) sll_node_data(sll_head_node(sll)); */ +/* } */ + +static int cmp_double(const void *lhs, const void *rhs) { + return * (double *) lhs - * (double *) rhs; +} + +static void median_final(sqlite3_context *context) { + sll_t *data = (sll_t *) sqlite3_aggregate_context(context, sizeof(*data)); + + const uint64_t count = sll_get_size(data); + double median = 0; + + /* skip some mallocs */ + if (count == 0) { + goto cleanup; + } + else if (count == 1) { + median = (double) (uintptr_t) sll_node_data(sll_head_node(data)); + goto ret_median; + } + else if (count == 2) { + median = ((double) (uintptr_t) sll_node_data(sll_head_node(data)) + + (double) (uintptr_t) sll_node_data(sll_tail_node(data))) / 2.0; + goto ret_median; + } + + const uint64_t half = count / 2; + + double *arr = malloc(count * sizeof(double)); + size_t i = 0; + sll_loop(data, node) { + arr[i++] = (double) (uintptr_t) sll_node_data(node); + } + + qsort(arr, count, sizeof(double), cmp_double); + + median = arr[half]; + if (!(count & 1)) { + median += arr[half - 1]; + median /= 2.0; + } + free(arr); + + /* median = quickselect(data, count, half); */ + /* if (!(count & 1)) { */ + /* median += quickselect(data, count, half - 1); */ + /* median /= 2.0; */ + /* } */ + + ret_median: + sqlite3_result_double(context, median); + + cleanup: + sll_destroy(data, NULL); +} + int addqueryfuncs_common(sqlite3 *db) { return !((sqlite3_create_function(db, "uidtouser", 1, SQLITE_UTF8, NULL, &uidtouser, NULL, NULL) == SQLITE_OK) && @@ -1094,7 +1234,9 @@ int addqueryfuncs_common(sqlite3 *db) { (sqlite3_create_function(db, "basename", 1, SQLITE_UTF8, NULL, &sqlite_basename, NULL, NULL) == SQLITE_OK) && (sqlite3_create_function(db, "stdev", 1, SQLITE_UTF8, - NULL, NULL, stdev_step, stdev_final) == SQLITE_OK)); + NULL, NULL, stdev_step, stdev_final) == SQLITE_OK) && + (sqlite3_create_function(db, "median", 1, SQLITE_UTF8, + NULL, NULL, median_step, median_final) == SQLITE_OK)); } int addqueryfuncs_with_context(sqlite3 *db, struct work *work) { diff --git a/test/unit/googletest/dbutils.cpp b/test/unit/googletest/dbutils.cpp index 7e44629d8..c2088baa6 100644 --- a/test/unit/googletest/dbutils.cpp +++ b/test/unit/googletest/dbutils.cpp @@ -577,7 +577,10 @@ TEST(addqueryfuncs, basename) { sqlite3_close(db); } -int double_callback(void *arg, int, char **data, char **) { +static int double_callback(void *arg, int, char **data, char **) { + if (!data[0]) { + return SQLITE_ERROR; + } return !(sscanf(data[0], "%lf", (double *) arg) == 1); } @@ -588,15 +591,41 @@ TEST(addqueryfuncs, stdev) { ASSERT_EQ(addqueryfuncs(db, 0, nullptr), 0); ASSERT_EQ(sqlite3_exec(db, "CREATE TABLE t (value INT);", nullptr, nullptr, nullptr), SQLITE_OK); - ASSERT_EQ(sqlite3_exec(db, "INSERT INTO t (value) VALUES (1), (2), (3), (4), (5);", nullptr, nullptr, nullptr), SQLITE_OK); double stdev = 0; + EXPECT_NE(sqlite3_exec(db, "SELECT stdev(value) FROM t", double_callback, &stdev, nullptr), SQLITE_OK); + + ASSERT_EQ(sqlite3_exec(db, "INSERT INTO t (value) VALUES (1), (2), (3), (4), (5);", nullptr, nullptr, nullptr), SQLITE_OK); + EXPECT_EQ(sqlite3_exec(db, "SELECT stdev(value) FROM t", double_callback, &stdev, nullptr), SQLITE_OK); EXPECT_DOUBLE_EQ(stdev * stdev * 2, (double) 5); /* sqrt(5 / 2) */ sqlite3_close(db); } +TEST(addqueryfuncs, median) { + sqlite3 *db = nullptr; + ASSERT_EQ(sqlite3_open(":memory:", &db), SQLITE_OK); + ASSERT_NE(db, nullptr); + + ASSERT_EQ(addqueryfuncs(db, 0, nullptr), 0); + ASSERT_EQ(sqlite3_exec(db, "CREATE TABLE t (value INT);", nullptr, nullptr, nullptr), SQLITE_OK); + + double median = 0; + EXPECT_NE(sqlite3_exec(db, "SELECT median(value) FROM t", double_callback, &median, nullptr), SQLITE_OK); + + ASSERT_EQ(sqlite3_exec(db, "INSERT INTO t (value) VALUES (1), (2), (3), (4), (5);", nullptr, nullptr, nullptr), SQLITE_OK); + + EXPECT_EQ(sqlite3_exec(db, "SELECT median(value) FROM t", double_callback, &median, nullptr), SQLITE_OK); + EXPECT_DOUBLE_EQ(median , (double) 3); + + ASSERT_EQ(sqlite3_exec(db, "INSERT INTO t (value) VALUES (6);", nullptr, nullptr, nullptr), SQLITE_OK); + EXPECT_EQ(sqlite3_exec(db, "SELECT median(value) FROM t", double_callback, &median, nullptr), SQLITE_OK); + EXPECT_DOUBLE_EQ(median , (double) 3.5); + + sqlite3_close(db); +} + TEST(sqlite_uri_path, none) { const char src[] = "prefix/basename"; size_t src_len = strlen(src);