Skip to content

Commit

Permalink
add REST test and javadoc
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Jan 25, 2025
1 parent 368f1af commit 0960694
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ public static JsonObject getJsonObjectFromString(String jsonString) {
* JsonPath expression (e.g., "$.store.book[0].title").
* @return true if the path exists in the JSON object, false otherwise.
* @throws IllegalArgumentException if the json object is null or if the path is null or empty.
* @throws PathNotFoundException if there's an error in parsing the JSON or the path.
*/
public static boolean pathExists(Object json, String path) {
if (json == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private void rewriteQueryString(
* @param queryString the original query string
* @param requestListener the {@link ActionListener} to be notified when the query string or query template is updated
* @param processOutputMap the list of output mappings
* @param requestContext
* @param requestContext the request context that is carried across search processors within the search pipeline
* @return an {@link ActionListener} that handles the response from the ML model inference
*/
private ActionListener<Map<Integer, MLOutput>> createRewriteRequestListener(
Expand Down Expand Up @@ -316,6 +316,9 @@ private void updateIncomeQueryObject(
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
requestContext.setAttribute(newQueryField, modelOutputValue);

// If the output mapping field uses a JsonPath starting with "$.ext." or a dot path starting
// with "ext.", the model output is written into the search extension; prepare the nested
// structure in the query first, for example {"ext":{"ml_inference":{}}}
if (newQueryField.startsWith("$.ext.") || newQueryField.startsWith("ext.")) {
incomeQueryObject = StringUtils.prepareNestedStructures(incomeQueryObject, newQueryField);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,24 @@ default String convertToDotPath(String path) {
return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
}

/**
* Sets the request context from the extensions in the SearchRequest.
*
* This method processes the extensions in the provided SearchRequest and sets
* corresponding attributes in the PipelineProcessingContext. It specifically
* handles two types of extensions:
* 1. MLInferenceRequestParametersExtBuilder
* 2. GenerativeQAParamExtBuilder
*
* For each recognized extension, it extracts parameters and sets them as
* attributes in the requestContext with appropriate prefixes.
*
* @param request The SearchRequest containing the extensions to process.
* This should be a valid SearchRequest that may contain
* ML Inference or Generative QA extensions.
* @param requestContext The PipelineProcessingContext where attributes will be set.
* This context will be updated with parameters from the extensions.
*/
default void setRequestContextFromExt(SearchRequest request, PipelineProcessingContext requestContext) {

List<SearchExtBuilder> extBuilderList = request.source().ext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
*
*/
public class MLInferenceRequestParametersUtil {

Check warning on line 25 in plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java

View check run for this annotation

Codecov / codecov/patch

plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java#L25

Added line #L25 was not covered by tests

/**
 * Extracts ML Inference Request Parameters from a SearchRequest.
 *
 * This method examines the provided SearchRequest for ML-inference parameters
 * that are embedded within the request's extensions. It specifically looks for
 * the MLInferenceRequestParametersExtBuilder and extracts the ML Inference
 * Request Parameters if one is present.
 *
 * @param searchRequest the search request whose extensions are examined
 */
public static MLInferenceRequestParameters getMLInferenceRequestParameters(SearchRequest searchRequest) {
MLInferenceRequestParametersExtBuilder mLInferenceRequestParametersExtBuilder = null;
if (searchRequest.source() != null && searchRequest.source().ext() != null && !searchRequest.source().ext().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,73 @@ public void testMLInferenceProcessorRemoteModelRewriteQueryString() throws Excep
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school");
}

/**
* Tests the ML inference processor with a remote model to rewrite the query string.
* It creates a search pipeline with the ML inference processor,
* the ml inference processor takes model input from search extension
* and then performs a search using the pipeline. The test verifies that the query string is rewritten
* correctly based on the inference results from the remote model.
*
* @throws Exception if any error occurs during the test
*/
public void testMLInferenceProcessorRemoteModelRewriteQueryStringWithSearchExtension() throws Exception {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
String createPipelineRequestBody = "{\n"
+ " \"request_processors\": [\n"
+ " {\n"
+ " \"ml_inference\": {\n"
+ " \"tag\": \"ml_inference\",\n"
+ " \"description\": \"This processor is going to run ml inference during search request\",\n"
+ " \"model_id\": \""
+ this.openAIChatModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
+ " \"input\": \"ext.ml_inference.query_text\"\n"
+ " }\n"
+ " ],\n"
+ " \"output_map\": [\n"
+ " {\n"
+ " \"query.term.diary_embedding_size.value\": \"data[0].embedding.length()\"\n"
+ " }\n"
+ " ],\n"
+ " \"ignore_missing\":false,\n"
+ " \"ignore_failure\": false\n"
+ " \n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";

String query = "{\n"
+ " \"query\": {\n"
+ " \"term\": {\n"
+ " \"diary_embedding_size\": {\n"
+ " \"value\": \"any\"\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"ext\": {\n"
+ " \"ml_inference\": {\n"
+ " \"query_text\": \"foo\"\n"
+ " }\n"
+ " }\n"
+ "}";
String index_name = "daily_index";
String pipelineName = "diary_embedding_pipeline";
createSearchPipelineProcessor(createPipelineRequestBody, pipelineName);

Map response = searchWithPipeline(client(), index_name, pipelineName, query);

Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536");
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "rainy");
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy");
Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school");
}

/**
* Tests the ML inference processor with a remote model to rewrite the query type.
* It creates a search pipeline with the ML inference processor configured to rewrite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,78 @@ public void testMLInferenceProcessorRemoteModelStringField() throws Exception {
Assert.assertEquals((Double) embeddingList.get(1), 0.87109375, 0.005);
}

/**
 * Verifies that the MLInferenceSearchResponseProcessor can take its model input from a
 * string field in the {@code ext.ml_inference} search extension (addressed through the
 * {@code $._request.} JsonPath prefix). A search pipeline is created with the processor
 * bound to a remote Bedrock embedding model in one-to-one prediction mode, a search is
 * run through the pipeline, and the embedding written to each hit is asserted.
 *
 * @throws Exception if pipeline creation or the search request fails
 */
@Test
public void testMLInferenceProcessorRemoteModelStringFieldWithSearchExtension() throws Exception {
    // No AWS credentials configured in this environment; nothing to verify.
    if (AWS_ACCESS_KEY_ID == null) {
        return;
    }

    // Processor definition: model input is read from the original request's search
    // extension; the embedding output is appended to each hit's _source.
    String pipelineDefinition = "{\n"
        + " \"response_processors\": [\n"
        + " {\n"
        + " \"ml_inference\": {\n"
        + " \"tag\": \"ml_inference\",\n"
        + " \"description\": \"This processor is going to run ml inference during search response\",\n"
        + " \"model_id\": \""
        + this.bedrockEmbeddingModelId
        + "\",\n"
        + " \"input_map\": [\n"
        + " {\n"
        + " \"input\": \"$._request.ext.ml_inference.query_text\"\n"
        + " }\n"
        + " ],\n"
        + " \"output_map\": [\n"
        + " {\n"
        + " \"weather_embedding\": \"$.embedding\"\n"
        + " }\n"
        + " ],\n"
        + " \"ignore_missing\": false,\n"
        + " \"ignore_failure\": false,\n"
        + " \"one_to_one\": true\n"
        + " }\n"
        + " }\n"
        + " ]\n"
        + "}";

    // Search body carrying the ml_inference extension the processor reads from.
    String searchRequestBody = "{\n"
        + " \"query\": {\n"
        + " \"term\": {\n"
        + " \"diary\": {\n"
        + " \"value\": \"happy\"\n"
        + " }\n"
        + " }\n"
        + " },\n"
        + " \"ext\": {\n"
        + " \"ml_inference\": {\n"
        + " \"query_text\": \"sunny\"\n"
        + " }\n"
        + " }\n"
        + "}";

    String indexName = "daily_index";
    String pipelineId = "weather_embedding_pipeline";
    createSearchPipelineProcessor(pipelineDefinition, pipelineId);

    Map response = searchWithPipeline(client(), indexName, pipelineId, searchRequestBody);

    String sourcePath = "$.hits.hits[0]._source";
    Assert.assertEquals(JsonPath.parse(response).read(sourcePath + ".diary_embedding_size"), "1536");
    Assert.assertEquals(JsonPath.parse(response).read(sourcePath + ".weather"), "sunny");
    Assert.assertEquals(JsonPath.parse(response).read(sourcePath + ".diary[0]"), "happy");
    Assert.assertEquals(JsonPath.parse(response).read(sourcePath + ".diary[1]"), "first day at school");

    // The 1536-dim embedding should have been written onto the hit by the processor.
    List embeddingList = (List) JsonPath.parse(response).read(sourcePath + ".weather_embedding");
    Assert.assertEquals(embeddingList.size(), 1536);
    Assert.assertEquals((Double) embeddingList.get(0), 0.734375, 0.005);
    Assert.assertEquals((Double) embeddingList.get(1), 0.87109375, 0.005);
}

/**
* Tests the MLInferenceSearchResponseProcessor with a remote model and a nested list field as input.
* It creates a search pipeline with the processor configured to use the remote model,
Expand Down

0 comments on commit 0960694

Please sign in to comment.