Skip to content

Commit

Permalink
KNN support for spoints
Browse files Browse the repository at this point in the history
  • Loading branch information
waaeer authored and vitcpp committed Dec 11, 2023
1 parent 0be5757 commit c772c92
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ DATA_built = $(RELEASE_SQL) \
DOCS = README.pg_sphere COPYRIGHT.pg_sphere
TESTS = version tables points euler circle line ellipse poly path box \
index contains_ops contains_ops_compat bounding_box_gist gnomo \
epochprop contains overlaps spoint_brin sbox_brin selectivity
epochprop contains overlaps spoint_brin sbox_brin selectivity knn
REGRESS = init $(TESTS)

PG_CFLAGS += -DPGSPHERE_VERSION=$(PGSPHERE_VERSION)
Expand Down
13 changes: 12 additions & 1 deletion doc/indices.sgm
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@
</para>
</listitem>
</itemizedlist>
<para>
GiST index can be used also for fast finding points closest to the given one
when ordering by an expression with the <literal>&lt;-&gt;</literal> operator
is used, as shown in an example below.
</para>
<para>
BRIN indexing supports just spherical points (<type>spoint</type>)
and spherical coordinates range (<type>sbox</type>) at the moment.
Expand All @@ -82,6 +87,13 @@
<![CDATA[CREATE INDEX test_pos_idx ON test USING GIST (pos);]]>
<![CDATA[VACUUM ANALYZE test;]]>
</programlisting>
<para>
To find points closest to a given spherical position, use the <literal>&lt;-&gt;</literal> operator:
</para>
<programlisting>
<![CDATA[SELECT * FROM test ORDER BY pos <-> spoint (0.2, 0.3) LIMIT 10 ]]>
</programlisting>

<para>
BRIN index can be created through the following syntax:
</para>
Expand All @@ -100,7 +112,6 @@
<![CDATA[CREATE INDEX test_pos_idx USING BRIN ON test (pos) WITH (pages_per_range = 16);]]>
</programlisting>
</example>

</sect1>

<sect1 id="ind.smoc">
Expand Down
122 changes: 122 additions & 0 deletions expected/knn.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
CREATE TABLE points (id int, p spoint, pos int);
INSERT INTO points (id, p) SELECT x, spoint(random()*6.28, (2*random()-1)*1.57) FROM generate_series(1,314159) x;
CREATE INDEX i ON points USING gist (p);
SET enable_indexscan = true;
EXPLAIN (costs off) SELECT p <-> spoint (0.2, 0.3) FROM points ORDER BY 1 LIMIT 100;
QUERY PLAN
-------------------------------------------------
Limit
-> Index Scan using i on points
Order By: (p <-> '(0.2 , 0.3)'::spoint)
(3 rows)

UPDATE points SET pos = n FROM
(SELECT id, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100) sel
WHERE points.id = sel.id;
SET enable_indexscan = false;
SELECT pos, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100;
pos | n
-----+-----
1 | 1
2 | 2
3 | 3
4 | 4
5 | 5
6 | 6
7 | 7
8 | 8
9 | 9
10 | 10
11 | 11
12 | 12
13 | 13
14 | 14
15 | 15
16 | 16
17 | 17
18 | 18
19 | 19
20 | 20
21 | 21
22 | 22
23 | 23
24 | 24
25 | 25
26 | 26
27 | 27
28 | 28
29 | 29
30 | 30
31 | 31
32 | 32
33 | 33
34 | 34
35 | 35
36 | 36
37 | 37
38 | 38
39 | 39
40 | 40
41 | 41
42 | 42
43 | 43
44 | 44
45 | 45
46 | 46
47 | 47
48 | 48
49 | 49
50 | 50
51 | 51
52 | 52
53 | 53
54 | 54
55 | 55
56 | 56
57 | 57
58 | 58
59 | 59
60 | 60
61 | 61
62 | 62
63 | 63
64 | 64
65 | 65
66 | 66
67 | 67
68 | 68
69 | 69
70 | 70
71 | 71
72 | 72
73 | 73
74 | 74
75 | 75
76 | 76
77 | 77
78 | 78
79 | 79
80 | 80
81 | 81
82 | 82
83 | 83
84 | 84
85 | 85
86 | 86
87 | 87
88 | 88
89 | 89
90 | 90
91 | 91
92 | 92
93 | 93
94 | 94
95 | 95
96 | 96
97 | 97
98 | 98
99 | 99
100 | 100
(100 rows)

DROP TABLE points;
7 changes: 6 additions & 1 deletion pgs_gist.sql.in
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,15 @@ CREATE FUNCTION g_spoint_compress(internal)
AS 'MODULE_PATHNAME', 'g_spoint_compress'
LANGUAGE 'c';


CREATE FUNCTION g_spoint_consistent(internal, internal, int4, oid, internal)
RETURNS internal
AS 'MODULE_PATHNAME', 'g_spoint_consistent'
LANGUAGE 'c';

CREATE FUNCTION g_spoint_distance(internal, spoint, smallint, oid, internal)
RETURNS internal
AS 'MODULE_PATHNAME', 'g_spoint_distance'
LANGUAGE 'c';

CREATE OPERATOR CLASS spoint
DEFAULT FOR TYPE spoint USING gist AS
Expand All @@ -114,6 +117,7 @@ CREATE OPERATOR CLASS spoint
OPERATOR 14 @ (spoint, spoly),
OPERATOR 15 @ (spoint, sellipse),
OPERATOR 16 @ (spoint, sbox),
OPERATOR 17 <-> (spoint, spoint) FOR ORDER BY float_ops,
OPERATOR 37 <@ (spoint, scircle),
OPERATOR 38 <@ (spoint, sline),
OPERATOR 39 <@ (spoint, spath),
Expand All @@ -127,6 +131,7 @@ CREATE OPERATOR CLASS spoint
FUNCTION 5 g_spherekey_penalty (internal, internal, internal),
FUNCTION 6 g_spherekey_picksplit (internal, internal),
FUNCTION 7 g_spherekey_same (spherekey, spherekey, internal),
FUNCTION 8 g_spoint_distance (internal, spoint, smallint, oid, internal),
STORAGE spherekey;


Expand Down
12 changes: 12 additions & 0 deletions sql/knn.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE TABLE points (id int, p spoint, pos int);
INSERT INTO points (id, p) SELECT x, spoint(random()*6.28, (2*random()-1)*1.57) FROM generate_series(1,314159) x;
CREATE INDEX i ON points USING gist (p);
SET enable_indexscan = true;
EXPLAIN (costs off) SELECT p <-> spoint (0.2, 0.3) FROM points ORDER BY 1 LIMIT 100;
UPDATE points SET pos = n FROM
(SELECT id, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100) sel
WHERE points.id = sel.id;
SET enable_indexscan = false;
SELECT pos, row_number() OVER (ORDER BY p <-> spoint (0.2, 0.3)) n FROM points ORDER BY p <-> spoint (0.2, 0.3) LIMIT 100;
DROP TABLE points;

126 changes: 126 additions & 0 deletions src/gist.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ PG_FUNCTION_INFO_V1(g_spherekey_penalty);
PG_FUNCTION_INFO_V1(g_spherekey_picksplit);
PG_FUNCTION_INFO_V1(g_spoint3_penalty);
PG_FUNCTION_INFO_V1(g_spoint3_picksplit);
PG_FUNCTION_INFO_V1(g_spoint_distance);
PG_FUNCTION_INFO_V1(g_spoint3_distance);
PG_FUNCTION_INFO_V1(g_spoint3_fetch);

Expand Down Expand Up @@ -681,6 +682,10 @@ g_spoint3_consistent(PG_FUNCTION_ARGS)
PG_RETURN_BOOL(false);
}

static double distance_vector_point_3d (Vector3D* v, double x, double y, double z) {
return acos ( (v->x * x + v->y * y + v->z * z) / sqrt( x*x + y*y + z*z ) ); // as v has length=1 by design
}

Datum
g_spoint3_distance(PG_FUNCTION_ARGS)
{
Expand Down Expand Up @@ -1672,6 +1677,127 @@ fallbackSplit(Box3D *boxes, OffsetNumber maxoff, GIST_SPLITVEC *v)
v->spl_ldatum_exists = v->spl_rdatum_exists = false;
}


Datum
g_spoint_distance(PG_FUNCTION_ARGS)
{
GISTENTRY *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
StrategyNumber strategy = (StrategyNumber) PG_GETARG_UINT16(2);
Box3D* box = (Box3D *) DatumGetPointer(entry->key);
double retval;
SPoint *point = (SPoint *) PG_GETARG_POINTER(1);
Vector3D v_point, v_low, v_high;

switch (strategy)
{
case 17:
// Prepare data for calculation
spoint_vector3d(&v_point, point);
v_low.x = (double)box->low.coord[0] / MAXCVALUE;
v_low.y = (double)box->low.coord[1] / MAXCVALUE;
v_low.z = (double)box->low.coord[2] / MAXCVALUE;
v_high.x = (double)box->high.coord[0] / MAXCVALUE;
v_high.y = (double)box->high.coord[1] / MAXCVALUE;
v_high.z = (double)box->high.coord[2] / MAXCVALUE;
// a box splits space into 27 subspaces (6+12+8+1) with different distance calculation
if(v_point.x < v_low.x) {
if(v_point.y < v_low.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_low.z); //point2point distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_point.z); //point2line distance
} else {
retval = distance_vector_point_3d (&v_point, v_low.x, v_low.y, v_high.z); //point2point distance
}
} else if(v_point.y < v_high.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y , v_low.z); //point2line distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y , v_point.z); //point2plane distance
} else {
retval = distance_vector_point_3d (&v_point, v_low.x, v_point.y, v_high.z); //point2line distance
}
} else {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_low.z); //point2point distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_point.z); //point2line distance
} else {
retval = distance_vector_point_3d (&v_point, v_low.x, v_high.y, v_high.z); //point2point distance
}
}
} else if(v_point.x < v_high.x) {
if(v_point.y < v_low.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_low.z); //p2line distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_point.z); //point2plane distance
} else {
retval = distance_vector_point_3d (&v_point, v_point.x, v_low.y, v_high.z); //point2line distance
}
} else if(v_point.y < v_high.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_point.x, v_point.y , v_low.z); //point2plane distance
} else if (v_point.z < v_high.z) {
retval = 0; // inside cube
} else {
retval = distance_vector_point_3d (&v_point, v_point.x, v_point.y, v_high.z); //point2plane distance
}
} else {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_low.z); //point2line distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_point.z); //point2plane distance
} else {
retval = distance_vector_point_3d (&v_point, v_point.x, v_high.y, v_high.z); //point2line distance
}
}
} else {
if(v_point.y < v_low.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_low.z); //p2p distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_point.z); //point2line distance
} else {
retval = distance_vector_point_3d (&v_point, v_high.x, v_low.y, v_high.z); //point2point distance
}
} else if(v_point.y < v_high.y) {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y , v_low.z); //point2line distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y , v_point.z); //point2plane distance
} else {
retval = distance_vector_point_3d (&v_point, v_high.x, v_point.y, v_high.z); //point2line distance
}
} else {
if(v_point.z < v_low.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_low.z); //point2point distance
} else if (v_point.z < v_high.z) {
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_point.z); //point2line distance
} else {
retval = distance_vector_point_3d (&v_point, v_high.x, v_high.y, v_high.z); //point2point distance
}
}
}

elog(DEBUG1, "distance (%lg,%lg,%lg %lg,%lg,%lg) <-> (%lg,%lg) = %lg",
v_low.x, v_low.y, v_low.z,
v_high.x, v_high.y, v_high.z,
point->lng, point->lat,
retval
);
break;

default:
elog(ERROR, "unrecognized cube strategy number: %d", strategy);
retval = 0; /* keep compiler quiet */
break;
}
PG_RETURN_FLOAT8(retval);
}



/*
* Represents information about an entry that can be placed to either group
* without affecting overlap over selected axis ("common entry").
Expand Down

0 comments on commit c772c92

Please sign in to comment.