Skip to content

Commit

Permalink
Added experimental knn option
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 3, 2024
1 parent c95a743 commit a83030c
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 150 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 5.4.0 (unreleased)

- Added experimental `knn` option
- Added experimental support for `_raw` to `where` option
- Added warning for `exists` with non-`true` values
- Added warning for full reindex and `:queue` mode
Expand Down
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1845,9 +1845,21 @@ To query nested data, use dot notation.
Product.search("san", fields: ["store.city"], where: {"store.zip_code" => 12345})
```
## Nearest Neighbors
## Nearest Neighbors [experimental]
You can use custom mapping and searching to index vectors and perform k-nearest neighbor search. See the examples for [Elasticsearch](examples/elasticsearch_knn.rb) and [OpenSearch](examples/opensearch_knn.rb).
*Available for Elasticsearch 8.6+ and OpenSearch 2.4+*
```ruby
class Product < ApplicationRecord
searchkick knn: {embedding: {dimensions: 3}}
end
```
Reindex and search with:
```ruby
Product.search(knn: {field: :embedding, vector: [1, 2, 3]})
```
## Reference
Expand Down
9 changes: 0 additions & 9 deletions examples/Gemfile

This file was deleted.

62 changes: 0 additions & 62 deletions examples/elasticsearch_knn.rb

This file was deleted.

74 changes: 0 additions & 74 deletions examples/opensearch_knn.rb

This file was deleted.

9 changes: 9 additions & 0 deletions lib/searchkick.rb
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ def self.server_below?(version, true_version = false)
Gem::Version.new(server_version.split("-")[0]) < Gem::Version.new(version.split("-")[0])
end

# private
# Reports whether the connected server is new enough for knn search.
#
# OpenSearch is checked against 2.4.0 (passing +true+ as the second
# argument to +server_below?+); any other server is checked against 8.6.0.
#
# @return [Boolean] true when the server is not below the minimum version
def self.knn_support?
  return !server_below?("8.6.0") unless opensearch?

  !server_below?("2.4.0", true)
end

def self.search(term = "*", model: nil, **options, &block)
options = options.dup
klass = model
Expand Down
35 changes: 35 additions & 0 deletions lib/searchkick/index_options.rb
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ def generate_settings
max_shingle_diff: 4
}

if options[:knn]
unless Searchkick.knn_support?
if Searchkick.opensearch?
raise Error, "knn requires OpenSearch 2.4+"
else
raise Error, "knn requires Elasticsearch 8.6+"
end
end

if Searchkick.opensearch?
settings[:index][:knn] = true
end
end

if options[:case_sensitive]
settings[:analysis][:analyzer].each do |_, analyzer|
analyzer[:filter].delete("lowercase")
Expand Down Expand Up @@ -406,6 +420,27 @@ def generate_mappings
mapping[field] = shape_options.merge(type: "geo_shape")
end

(options[:knn] || []).each do |field, knn_options|
if Searchkick.opensearch?
mapping[field.to_s] = {
type: "knn_vector",
dimension: knn_options[:dimensions],
method: {
name: "hnsw",
space_type: "cosinesimil",
engine: "lucene"
}
}
else
mapping[field.to_s] = {
type: "dense_vector",
dims: knn_options[:dimensions],
index: true,
similarity: "cosine"
}
end
end

if options[:inheritance]
mapping[:type] = keyword_mapping
end
Expand Down
2 changes: 1 addition & 1 deletion lib/searchkick/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ def searchkick(**options)
options = Searchkick.model_options.merge(options)

unknown_keywords = options.keys - [:_all, :_type, :batch_size, :callbacks, :case_sensitive, :conversions, :deep_paging, :default_fields,
:filterable, :geo_shape, :highlight, :ignore_above, :index_name, :index_prefix, :inheritance, :language,
:filterable, :geo_shape, :highlight, :ignore_above, :index_name, :index_prefix, :inheritance, :knn, :language,
:locations, :mappings, :match, :max_result_window, :merge_mappings, :routing, :searchable, :search_synonyms, :settings, :similarity,
:special_characters, :stem, :stemmer, :stem_conversions, :stem_exclusion, :stemmer_override, :suggest, :synonyms, :text_end,
:text_middle, :text_start, :unscope, :word, :word_end, :word_middle, :word_start]
Expand Down
34 changes: 33 additions & 1 deletion lib/searchkick/query.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Query
def initialize(klass, term = "*", **options)
unknown_keywords = options.keys - [:aggs, :block, :body, :body_options, :boost,
:boost_by, :boost_by_distance, :boost_by_recency, :boost_where, :conversions, :conversions_term, :debug, :emoji, :exclude, :explain,
:fields, :highlight, :includes, :index_name, :indices_boost, :limit, :load,
:fields, :highlight, :includes, :index_name, :indices_boost, :knn, :limit, :load,
:match, :misspellings, :models, :model_includes, :offset, :operator, :order, :padding, :page, :per_page, :profile,
:request_params, :routing, :scope_results, :scroll, :select, :similar, :smart_aggs, :suggest, :total_entries, :track, :type, :where]
raise ArgumentError, "unknown keywords: #{unknown_keywords.join(", ")}" if unknown_keywords.any?
Expand Down Expand Up @@ -526,6 +526,38 @@ def prepare
end
end

# knn
knn = options[:knn]
if knn
if term != "*"
raise ArgumentError, "Hybrid search not supported yet"
end

field = knn[:field]
vector = knn[:vector]
k = per_page + offset
filter = payload.delete(:query)

if Searchkick.opensearch?
payload[:query] = {
knn: {
field.to_sym => {
vector: vector,
k: k,
filter: filter
}
}
}
else
payload[:knn] = {
field: field,
query_vector: vector,
k: k,
filter: filter
}
end
end

# pagination
pagination_options = options[:page] || options[:limit] || options[:per_page] || options[:offset] || options[:padding]
if !options[:body] || pagination_options
Expand Down
36 changes: 36 additions & 0 deletions test/knn_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
require_relative "test_helper"

# Integration tests for the experimental knn search option.
class KnnTest < Minitest::Test
  # Skips every test in this class when Searchkick reports no knn support.
  def setup
    skip unless Searchkick.knn_support?
    super
  end

  # Results are ordered by similarity to the query vector; the test expects
  # a score near 1 for the identical vector and near 0 for the opposite one.
  def test_basic
    store [
      {name: "A", embedding: [1, 2, 3]},
      {name: "B", embedding: [-1, -2, -3]}
    ]
    assert_order "*", ["A", "B"], knn: knn_query

    closest, farthest = Product.search(knn: knn_query).hits.map { |hit| hit["_score"] }
    assert_in_delta 1, closest
    assert_in_delta 0, farthest
  end

  # A where clause restricts the candidates: "B" matches the vector exactly
  # but belongs to another store and must not be returned.
  def test_where
    store [
      {name: "A", store_id: 1, embedding: [1, 2, 3]},
      {name: "B", store_id: 2, embedding: [1, 2, 3]},
      {name: "C", store_id: 1, embedding: [-1, -2, -3]}
    ]
    assert_order "*", ["A", "C"], knn: knn_query, where: {store_id: 1}
  end

  # limit/offset page through the similarity-ordered results.
  def test_pagination
    store [
      {name: "A", embedding: [1, 2, 3]},
      {name: "B", embedding: [1, 2, 0]},
      {name: "C", embedding: [-1, -2, 0]},
      {name: "D", embedding: [-1, -2, -3]}
    ]
    assert_order "*", ["B", "C"], knn: knn_query, limit: 2, offset: 1
  end

  private

  # Builds a fresh knn option hash on every call so no test shares state.
  def knn_query
    {field: :embedding, vector: [1, 2, 3]}
  end
end
3 changes: 2 additions & 1 deletion test/models/product.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class Product
highlight: [:name],
filterable: [:name, :color, :description],
similarity: "BM25",
match: ENV["MATCH"] ? ENV["MATCH"].to_sym : nil
match: ENV["MATCH"] ? ENV["MATCH"].to_sym : nil,
knn: Searchkick.knn_support? ? {embedding: {dimensions: 3}} : nil

attr_accessor :conversions, :user_ids, :aisle, :details

Expand Down
7 changes: 7 additions & 0 deletions test/support/activerecord.rb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
t.decimal :longitude, precision: 10, scale: 7
t.text :description
t.text :alt_description
t.text :embedding
t.timestamps null: true
end

Expand Down Expand Up @@ -75,6 +76,12 @@

# ActiveRecord-backed Product model used by the test suite.
class Product < ActiveRecord::Base
  belongs_to :store

  # The embedding column is stored as JSON text. The coder is passed
  # positionally below Active Record 7.1 and as a keyword from 7.1 onwards.
  if ActiveRecord::VERSION::STRING.to_f < 7.1
    serialize :embedding, JSON
  else
    serialize :embedding, coder: JSON
  end
end

class Store < ActiveRecord::Base
Expand Down
1 change: 1 addition & 0 deletions test/support/mongoid.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Product
field :longitude, type: BigDecimal
field :description
field :alt_description
field :embedding, type: Array
end

class Store
Expand Down

0 comments on commit a83030c

Please sign in to comment.