Skip to content

Commit

Permalink
Merge pull request #12 from HumanBehaviourChangeProject/release-v4.1
Browse files Browse the repository at this point in the history
Add recommender call to API
  • Loading branch information
MartinGleize authored Mar 25, 2021
2 parents 4b853f5 + beddbdb commit 71ea066
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 6 deletions.
76 changes: 73 additions & 3 deletions core/src/main/java/com/ibm/drl/hbcp/api/PredictorController.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import javax.json.Json;
import javax.json.JsonArrayBuilder;
import javax.json.JsonObjectBuilder;
import javax.ws.rs.core.MediaType;

import com.ibm.drl.hbcp.core.attributes.Attribute;
import com.ibm.drl.hbcp.predictor.api.*;
import com.ibm.drl.hbcp.util.Environment;
import io.swagger.annotations.Example;
import io.swagger.annotations.ExampleProperty;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
Expand Down Expand Up @@ -76,6 +79,7 @@ public PredictorController(Properties baseProps) throws Exception {
jsonRefParser = new JSONRefParser(baseProps);
nodeVecsPerConfig = new HashMap<>();
translatingRankers = new HashMap<>();
// completely disable indexing of values
if (!Environment.isPredictionApiOnly()) { // outside of a docker-compose setting we fall back on a baseline
// if not done, index the papers
PaperIndexer.ensure(baseProps);
Expand Down Expand Up @@ -145,6 +149,47 @@ public String predictOutcomeEndpoint(
useNeuralPrediction).toPrettyString();
}

/**
* Returns the top K interventions by predicted outcome values, to a query asking for a combination of population characteristics,
* possible BC techniques, experimental settings.
* The results are scored with a confidence score. */
@ApiOperation(value = "Recommends best interventions for user scenarios.",
notes = "Returns the top K interventions by predicted outcome values, to a query asking for a combination of population characteristics, " +
"possible BC techniques, experimental settings." +
"As an example of 'query', you can use: " +
"[{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673272\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673273\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673274\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673275\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675715\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673282\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673283\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673284\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673285\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675717\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675718\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675719\",\"type\":\"boolean\",\"value\":true},{\"id\":\"5579088\",\"type\":\"numeric\",\"value\":30},{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50}]")
@RequestMapping(value = "/api/predict/recommend", method = RequestMethod.POST, consumes = "application/json", produces="application/json;charset=utf-8")
public String recommend(
@ApiParam(value = "Maximum number of recommended scenarios (the more, the longer the API call).")
@RequestParam(value="max", required = false, defaultValue = "10") int max,
@ApiParam(value = "A set of attribute-value pairs representing Behavior Change intervention scenarios serving as input")
@RequestBody AttributeValue[] query) throws IOException {
// turn into good old AVPs
List<AttributeValuePair> avps = Arrays.stream(query).map(AttributeValue::toAvp).collect(Collectors.toList());
// produce the candidate interventions to send to the prediction model
List<List<String>> recommendedInterventionsInfo = RecommendedInterventions.get().getRecommendedInterventions(avps, max).stream()
.map(ArrayList::new)
.collect(Collectors.toList());
List<List<AttributeValuePair>> queries = RecommendedInterventions.get().getRecommendedScenarios(avps, max);
// run the query
RankedResults<SearchResult> predictions = runBatchQueries(queries, 1, true, false, true);
// build the recommendation results
List<RecommendedIntervention> res = new ArrayList<>();
for (int i = 0; i < predictions.getResults().size(); i++) {
List<String> intervention = recommendedInterventionsInfo.get(i);
SearchResult prediction = predictions.getResults().get(i);
RecommendedIntervention reco = new RecommendedIntervention(intervention,
prediction.getNode().getNumericValue(), prediction.getScore());
res.add(reco);
}
// sort by predicted outcome value
res.sort(Comparator.comparing((RecommendedIntervention reco) -> -reco.getPredictedValue()));
// return JSON response
return Jsonable.toPrettyString(Json.createObjectBuilder()
.add("results", Jsonable.getJsonArrayFromCollection(res))
.build());
}

public RankedResults<SearchResult> predictOutcome(
List<String> populationAttributes,
List<String> interventionAttributes,
Expand Down Expand Up @@ -201,6 +246,25 @@ private RankedResults<SearchResult> runQuery(List<? extends AttributeValuePair>
}
}

private RankedResults<SearchResult> runBatchQueries(List<List<AttributeValuePair>> queries,
int topK, boolean useAnnotations, boolean useEffectSize,
boolean usePredictionApi) {
if (usePredictionApi || Environment.isPredictionApiOnly()) {
// create a new predictor service (very fast)
PredictionServiceConnector connector = PredictionServiceConnector.createForLocalService();
// request a prediction
List<PredictionServiceConnector.PredictionWithConfidence> predictionResponses = connector.requestPredictionBatch(queries);
List<SearchResult> res = predictionResponses.stream()
.map(predictionResponse -> new SearchResult(new AttributeValueNode(
new AttributeValuePair(Attributes.get().getFromName("Outcome value"), String.valueOf(predictionResponse.getValue()))),
predictionResponse.getConf()))
.collect(Collectors.toList());
return new RankedResults<>(res);
} else {
return new RankedResults<>(new ArrayList<>());
}
}



/**
Expand Down Expand Up @@ -270,7 +334,7 @@ private List<String> interventionOptions() {
@RequestMapping(value = "/api/predict/options/all", method = RequestMethod.GET, produces="application/json;charset=utf-8")
public String allInputOptions() {
List<String> options = Attributes.get().stream()
.map(Attribute::getName)
.map(Attribute::getId)
.collect(Collectors.toList());
return attributeInfo(options);
}
Expand All @@ -286,12 +350,18 @@ public String mockRelevantDocs() {
}

/** Returns all the input attributes handled by the prediction system, clustered by type. */
@ApiOperation(value = "Returns prediction insights: a comparison of a predicted outcome value and relevant scientific articles " +
"with their reported outcome value.")
@ApiOperation(value = "Returns prediction insights.",
notes = "Returns prediction insights: a comparison of a predicted outcome value and relevant scientific articles " +
"with their reported outcome value. As an example of query, you can use: " +
"[{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50},{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true}]")
@RequestMapping(value = "/api/predict/insights", method = RequestMethod.POST, consumes = "application/json", produces="application/json;charset=utf-8")
public String predictionInsights(
@ApiParam(value = "Whether to use the neural prediction model.")
@RequestParam(value="useneuralprediction", required = false, defaultValue = "false") boolean useNeuralPrediction,
@ApiParam(value = "A set of attribute-value pairs representing a Behavior Change intervention scenario serving as input",
examples = @Example(value = @ExampleProperty(
mediaType = MediaType.APPLICATION_JSON,
value = "[{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50},{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true}]")))
@RequestBody AttributeValue[] query) throws IOException {
// turn into good old AVPs
List<AttributeValuePair> avps = Arrays.stream(query).map(AttributeValue::toAvp).collect(Collectors.toList());
Expand Down
22 changes: 20 additions & 2 deletions core/src/main/java/com/ibm/drl/hbcp/parser/JSONRefParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.ibm.drl.hbcp.util.ParsingUtils;
import com.ibm.drl.hbcp.util.Props;
import lombok.Data;
import net.arnx.jsonic.JSON;
import org.apache.commons.lang3.NotImplementedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -486,11 +487,26 @@ public static void displayAllNameNumberAnnotations() throws IOException {
System.out.println(nameNumberValues.size());
}

public static void displayReachAttributes() throws IOException {
JSONRefParser parser = new JSONRefParser(new File("../data/jsons/All_annotations_512papers_05March20.json"));
for (AnnotatedAttributeValuePair avp : parser.getAttributeValuePairs()) {
if (avp.getAttribute().getName().contains(" analysed")) {
System.out.println("For doc: " + avp.getDocName());
System.out.println(avp);
}
}
}

public static void countAttributes() throws IOException {
JSONRefParser parser = new JSONRefParser(new File("data/jsons/All_annotations_512papers_05March20.json"));
JSONRefParser parser = new JSONRefParser(new File("../data/jsons/All_annotations_512papers_05March20.json"));
System.out.println("Attribute count: " + parser.getAttributeValuePairs().getAllAttributeIds().size());
}

public static void countAttributesPA() throws IOException {
JSONRefParser parser = new JSONRefParser(new File("../data/jsons/PhysicalActivity Sprint1ArmsAnd Prioritised47Papers.json"));
System.out.println("Attribute count for PA: " + parser.getAttributeValuePairs().getAllAttributeIds().size());
}

public static void mainTableGrammar() throws IOException {
// TODO: this file is not a resource yet
JSONRefParser parser = new JSONRefParser(new File("../data/jsons/TableGrammarAnnotations/Table Grammar Annotations .json"),
Expand Down Expand Up @@ -519,7 +535,9 @@ public static void mainTableGrammar() throws IOException {
}

public static void main(String[] args) throws IOException {
mainTableGrammar();
//mainTableGrammar();
//countAttributes();
//countAttributesPA();
displayReachAttributes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -26,6 +29,7 @@ public class PredictionServiceConnector {
private final int port;
private static final String PROTOCOL = "http";
private static final String ENDPOINT = "/hbcp/api/v1.0/predict/outcome/";
private static final String ENDPOINT_BATCH = "/hbcp/api/v1.0/predict/outcome/batch/";

private static final Logger log = LoggerFactory.getLogger(PredictionServiceConnector.class);

Expand All @@ -42,6 +46,58 @@ public static PredictionServiceConnector createForLocalService() {
return new PredictionServiceConnector(Environment.getPredictionURL(), Environment.getPredictionPort());
}

public List<PredictionWithConfidence> requestPredictionBatch(List<List<AttributeValuePair>> queriesAsAvps) {
// format the body to send in POST request
String queriesString = queriesAsAvps.stream()
.map(queryString -> queryString.stream()
.map(AttributeValueNode::new)
.map(AttributeValueNode::toString)
.collect(Collectors.joining("-")))
.collect(Collectors.joining("\n"))
.replaceAll("\\n+", "\n");
// open the HTTP connection with a POST request
HttpURLConnection con = null;
try {
// query the prediction API
URL predictionUrl = new URL(PROTOCOL, host, port, ENDPOINT_BATCH);
con = (HttpURLConnection) predictionUrl.openConnection();
con.setRequestMethod("POST");
con.setRequestProperty("Content-Type", "text/plain; charset=utf-8");
con.setRequestProperty("Accept", "application/json");
con.setDoInput(true);
con.setDoOutput(true);
// write the queries
try (OutputStream os = con.getOutputStream()) {
byte[] input = queriesString.getBytes(StandardCharsets.UTF_8);
os.write(input, 0, input.length);
}
// get the response
int status = con.getResponseCode();
if (status < 300) {
try (BufferedReader in = new BufferedReader(
new InputStreamReader(con.getInputStream(), StandardCharsets.UTF_8))) {
PredictionBatchResult batchResult = gson.fromJson(in, new TypeToken<PredictionBatchResult>() {
}.getType());
return batchResult.results;
}
} else {
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getErrorStream()))) {
log.error("Error in batch prediction API response for query: {}", queriesAsAvps);
String line;
while ((line = in.readLine()) != null) {
log.error(line);
}
}
return new ArrayList<>();
}
} catch (IOException e) {
log.error("Error while batch-querying the prediction API with: " + queriesString, e);
return new ArrayList<>();
} finally {
if (con != null) con.disconnect();
}
}

public Optional<PredictionWithConfidence> requestPrediction(List<? extends AttributeValuePair> avps) {
// format avps into prediction API input strings
String queryString = avps.stream()
Expand Down Expand Up @@ -85,6 +141,11 @@ public static class PredictionWithConfidence {
double conf;
}

@Value
public static class PredictionBatchResult {
List<PredictionWithConfidence> results;
}

public static void main(String[] args) throws IOException {
// test with the prediction docker container running locally on port 5000
PredictionServiceConnector con = new PredictionServiceConnector("127.0.0.1", 5000);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.ibm.drl.hbcp.predictor.api;

import lombok.Value;

import javax.json.Json;
import javax.json.JsonValue;
import java.util.List;

@Value
public class RecommendedIntervention implements Jsonable {

/** The IDs of the Behavior-Change techniques making up the intervention. */
List<String> bctIds;
/** Prediction result: predicted outcome value and confidence */
double predictedValue;
double confidence;

@Override
public JsonValue toJson() {
return Json.createObjectBuilder()
.add("intervention", Jsonable.getJsonArrayFromStrings(bctIds))
.add("predictedValue", predictedValue)
.add("confidence", confidence)
.build();
}
}
Loading

0 comments on commit 71ea066

Please sign in to comment.