diff --git a/core/src/main/java/com/ibm/drl/hbcp/api/PredictorController.java b/core/src/main/java/com/ibm/drl/hbcp/api/PredictorController.java index 6ff1143..f92f66d 100644 --- a/core/src/main/java/com/ibm/drl/hbcp/api/PredictorController.java +++ b/core/src/main/java/com/ibm/drl/hbcp/api/PredictorController.java @@ -7,10 +7,13 @@ import javax.json.Json; import javax.json.JsonArrayBuilder; import javax.json.JsonObjectBuilder; +import javax.ws.rs.core.MediaType; import com.ibm.drl.hbcp.core.attributes.Attribute; import com.ibm.drl.hbcp.predictor.api.*; import com.ibm.drl.hbcp.util.Environment; +import io.swagger.annotations.Example; +import io.swagger.annotations.ExampleProperty; import org.slf4j.LoggerFactory; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; @@ -76,6 +79,7 @@ public PredictorController(Properties baseProps) throws Exception { jsonRefParser = new JSONRefParser(baseProps); nodeVecsPerConfig = new HashMap<>(); translatingRankers = new HashMap<>(); + // completely disable indexing of values if (!Environment.isPredictionApiOnly()) { // outside of a docker-compose setting we fall back on a baseline // if not done, index the papers PaperIndexer.ensure(baseProps); @@ -145,6 +149,47 @@ public String predictOutcomeEndpoint( useNeuralPrediction).toPrettyString(); } + /** + * Returns the top K interventions by predicted outcome values, to a query asking for a combination of population characteristics, + * possible BC techniques, experimental settings. + * The results are scored with a confidence score. */ + @ApiOperation(value = "Recommends best interventions for user scenarios.", + notes = "Returns the top K interventions by predicted outcome values, to a query asking for a combination of population characteristics, " + + "possible BC techniques, experimental settings." 
+ + "As an example of 'query', you can use: " + + "[{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673272\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673273\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673274\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673275\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675715\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673282\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673283\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673284\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3673285\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675717\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675718\",\"type\":\"boolean\",\"value\":true},{\"id\":\"3675719\",\"type\":\"boolean\",\"value\":true},{\"id\":\"5579088\",\"type\":\"numeric\",\"value\":30},{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50}]") + @RequestMapping(value = "/api/predict/recommend", method = RequestMethod.POST, consumes = "application/json", produces="application/json;charset=utf-8") + public String recommend( + @ApiParam(value = "Maximum number of recommended scenarios (the more, the longer the API call).") + @RequestParam(value="max", required = false, defaultValue = "10") int max, + @ApiParam(value = "A set of attribute-value pairs representing Behavior Change intervention scenarios serving as input") + @RequestBody AttributeValue[] query) throws IOException { + // turn into good old AVPs + List avps = Arrays.stream(query).map(AttributeValue::toAvp).collect(Collectors.toList()); + // produce the candidate interventions to send to the prediction model + List> recommendedInterventionsInfo = RecommendedInterventions.get().getRecommendedInterventions(avps, max).stream() + .map(ArrayList::new) + .collect(Collectors.toList()); + List> queries = RecommendedInterventions.get().getRecommendedScenarios(avps, max); + // run the query + RankedResults predictions = runBatchQueries(queries, 1, true, false, true); + 
// build the recommendation results + List<RecommendedIntervention> res = new ArrayList<>(); + for (int i = 0; i < predictions.getResults().size(); i++) { + List<String> intervention = recommendedInterventionsInfo.get(i); + SearchResult prediction = predictions.getResults().get(i); + RecommendedIntervention reco = new RecommendedIntervention(intervention, + prediction.getNode().getNumericValue(), prediction.getScore()); + res.add(reco); + } + // sort by predicted outcome value + res.sort(Comparator.comparing((RecommendedIntervention reco) -> -reco.getPredictedValue())); + // return JSON response + return Jsonable.toPrettyString(Json.createObjectBuilder() + .add("results", Jsonable.getJsonArrayFromCollection(res)) + .build()); + } + public RankedResults<SearchResult> predictOutcome( List<AttributeValuePair> populationAttributes, List<AttributeValuePair> interventionAttributes, @@ -201,6 +246,25 @@ private RankedResults<SearchResult> runQuery(List<AttributeValuePair> } } + private RankedResults<SearchResult> runBatchQueries(List<List<AttributeValuePair>> queries, + int topK, boolean useAnnotations, boolean useEffectSize, + boolean usePredictionApi) { + if (usePredictionApi || Environment.isPredictionApiOnly()) { + // create a new predictor service (very fast) + PredictionServiceConnector connector = PredictionServiceConnector.createForLocalService(); + // request a prediction + List<PredictionServiceConnector.PredictionWithConfidence> predictionResponses = connector.requestPredictionBatch(queries); + List<SearchResult> res = predictionResponses.stream() + .map(predictionResponse -> new SearchResult(new AttributeValueNode( + new AttributeValuePair(Attributes.get().getFromName("Outcome value"), String.valueOf(predictionResponse.getValue()))), + predictionResponse.getConf())) + .collect(Collectors.toList()); + return new RankedResults<>(res); + } else { + return new RankedResults<>(new ArrayList<>()); + } + } + /** @@ -270,7 +334,7 @@ private List<String> interventionOptions() { @RequestMapping(value = "/api/predict/options/all", method = RequestMethod.GET, produces="application/json;charset=utf-8") public String allInputOptions() { List<String> options = Attributes.get().stream() - .map(Attribute::getName)
+ .map(Attribute::getId) .collect(Collectors.toList()); return attributeInfo(options); } @@ -286,12 +350,18 @@ public String mockRelevantDocs() { } /** Returns all the input attributes handled by the prediction system, clustered by type. */ - @ApiOperation(value = "Returns prediction insights: a comparison of a predicted outcome value and relevant scientific articles " + - "with their reported outcome value.") + @ApiOperation(value = "Returns prediction insights.", + notes = "Returns prediction insights: a comparison of a predicted outcome value and relevant scientific articles " + + "with their reported outcome value. As an example of query, you can use: " + + "[{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50},{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true}]") @RequestMapping(value = "/api/predict/insights", method = RequestMethod.POST, consumes = "application/json", produces="application/json;charset=utf-8") public String predictionInsights( @ApiParam(value = "Whether to use the neural prediction model.") @RequestParam(value="useneuralprediction", required = false, defaultValue = "false") boolean useNeuralPrediction, + @ApiParam(value = "A set of attribute-value pairs representing a Behavior Change intervention scenario serving as input", + examples = @Example(value = @ExampleProperty( + mediaType = MediaType.APPLICATION_JSON, + value = "[{\"id\":\"5579096\",\"type\":\"numeric\",\"value\":50},{\"id\":\"3673271\",\"type\":\"boolean\",\"value\":true}]"))) @RequestBody AttributeValue[] query) throws IOException { // turn into good old AVPs List avps = Arrays.stream(query).map(AttributeValue::toAvp).collect(Collectors.toList()); diff --git a/core/src/main/java/com/ibm/drl/hbcp/parser/JSONRefParser.java b/core/src/main/java/com/ibm/drl/hbcp/parser/JSONRefParser.java index 73b6cfb..4c86d49 100644 --- a/core/src/main/java/com/ibm/drl/hbcp/parser/JSONRefParser.java +++ b/core/src/main/java/com/ibm/drl/hbcp/parser/JSONRefParser.java @@ -20,6 +20,7 @@ import 
com.ibm.drl.hbcp.util.ParsingUtils; import com.ibm.drl.hbcp.util.Props; import lombok.Data; +import net.arnx.jsonic.JSON; import org.apache.commons.lang3.NotImplementedException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -486,11 +487,26 @@ public static void displayAllNameNumberAnnotations() throws IOException { System.out.println(nameNumberValues.size()); } + public static void displayReachAttributes() throws IOException { + JSONRefParser parser = new JSONRefParser(new File("../data/jsons/All_annotations_512papers_05March20.json")); + for (AnnotatedAttributeValuePair avp : parser.getAttributeValuePairs()) { + if (avp.getAttribute().getName().contains(" analysed")) { + System.out.println("For doc: " + avp.getDocName()); + System.out.println(avp); + } + } + } + public static void countAttributes() throws IOException { - JSONRefParser parser = new JSONRefParser(new File("data/jsons/All_annotations_512papers_05March20.json")); + JSONRefParser parser = new JSONRefParser(new File("../data/jsons/All_annotations_512papers_05March20.json")); System.out.println("Attribute count: " + parser.getAttributeValuePairs().getAllAttributeIds().size()); } + public static void countAttributesPA() throws IOException { + JSONRefParser parser = new JSONRefParser(new File("../data/jsons/PhysicalActivity Sprint1ArmsAnd Prioritised47Papers.json")); + System.out.println("Attribute count for PA: " + parser.getAttributeValuePairs().getAllAttributeIds().size()); + } + public static void mainTableGrammar() throws IOException { // TODO: this file is not a resource yet JSONRefParser parser = new JSONRefParser(new File("../data/jsons/TableGrammarAnnotations/Table Grammar Annotations .json"), @@ -519,7 +535,9 @@ public static void mainTableGrammar() throws IOException { } public static void main(String[] args) throws IOException { - mainTableGrammar(); + //mainTableGrammar(); //countAttributes(); + //countAttributesPA(); + displayReachAttributes(); } } \ No newline at end of file 
diff --git a/core/src/main/java/com/ibm/drl/hbcp/predictor/api/PredictionServiceConnector.java b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/PredictionServiceConnector.java index 7237dad..fffc8a7 100644 --- a/core/src/main/java/com/ibm/drl/hbcp/predictor/api/PredictionServiceConnector.java +++ b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/PredictionServiceConnector.java @@ -14,8 +14,11 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; +import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -26,6 +29,7 @@ public class PredictionServiceConnector { private final int port; private static final String PROTOCOL = "http"; private static final String ENDPOINT = "/hbcp/api/v1.0/predict/outcome/"; + private static final String ENDPOINT_BATCH = "/hbcp/api/v1.0/predict/outcome/batch/"; private static final Logger log = LoggerFactory.getLogger(PredictionServiceConnector.class); @@ -42,6 +46,58 @@ public static PredictionServiceConnector createForLocalService() { return new PredictionServiceConnector(Environment.getPredictionURL(), Environment.getPredictionPort()); } + public List requestPredictionBatch(List> queriesAsAvps) { + // format the body to send in POST request + String queriesString = queriesAsAvps.stream() + .map(queryString -> queryString.stream() + .map(AttributeValueNode::new) + .map(AttributeValueNode::toString) + .collect(Collectors.joining("-"))) + .collect(Collectors.joining("\n")) + .replaceAll("\\n+", "\n"); + // open the HTTP connection with a POST request + HttpURLConnection con = null; + try { + // query the prediction API + URL predictionUrl = new URL(PROTOCOL, host, port, ENDPOINT_BATCH); + con = (HttpURLConnection) predictionUrl.openConnection(); + con.setRequestMethod("POST"); + 
con.setRequestProperty("Content-Type", "text/plain; charset=utf-8"); + con.setRequestProperty("Accept", "application/json"); + con.setDoInput(true); + con.setDoOutput(true); + // write the queries + try (OutputStream os = con.getOutputStream()) { + byte[] input = queriesString.getBytes(StandardCharsets.UTF_8); + os.write(input, 0, input.length); + } + // get the response + int status = con.getResponseCode(); + if (status < 300) { + try (BufferedReader in = new BufferedReader( + new InputStreamReader(con.getInputStream(), StandardCharsets.UTF_8))) { + PredictionBatchResult batchResult = gson.fromJson(in, new TypeToken<PredictionBatchResult>() { + }.getType()); + return batchResult.results; + } + } else { + try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getErrorStream()))) { + log.error("Error in batch prediction API response for query: {}", queriesAsAvps); + String line; + while ((line = in.readLine()) != null) { + log.error(line); + } + } + return new ArrayList<>(); + } + } catch (IOException e) { + log.error("Error while batch-querying the prediction API with: " + queriesString, e); + return new ArrayList<>(); + } finally { + if (con != null) con.disconnect(); + } + } + public Optional<PredictionWithConfidence> requestPrediction(List<AttributeValuePair> avps) { // format avps into prediction API input strings String queryString = avps.stream() @@ -85,6 +141,11 @@ public static class PredictionWithConfidence { double conf; } + @Value + public static class PredictionBatchResult { + List<PredictionWithConfidence> results; + } + public static void main(String[] args) throws IOException { // test with the prediction docker container running locally on port 5000 PredictionServiceConnector con = new PredictionServiceConnector("127.0.0.1", 5000); diff --git a/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedIntervention.java b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedIntervention.java new file mode 100644 index 0000000..66218e9 --- /dev/null +++
b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedIntervention.java @@ -0,0 +1,26 @@ +package com.ibm.drl.hbcp.predictor.api; + +import lombok.Value; + +import javax.json.Json; +import javax.json.JsonValue; +import java.util.List; + +@Value +public class RecommendedIntervention implements Jsonable { + + /** The IDs of the Behavior-Change techniques making up the intervention. */ + List<String> bctIds; + /** Prediction result: predicted outcome value and confidence */ + double predictedValue; + double confidence; + + @Override + public JsonValue toJson() { + return Json.createObjectBuilder() + .add("intervention", Jsonable.getJsonArrayFromStrings(bctIds)) + .add("predictedValue", predictedValue) + .add("confidence", confidence) + .build(); + } +} diff --git a/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedInterventions.java b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedInterventions.java new file mode 100644 index 0000000..3aab44a --- /dev/null +++ b/core/src/main/java/com/ibm/drl/hbcp/predictor/api/RecommendedInterventions.java @@ -0,0 +1,123 @@ +package com.ibm.drl.hbcp.predictor.api; + +import com.google.common.collect.ConcurrentHashMultiset; +import com.google.common.collect.ImmutableMultiset; +import com.google.common.collect.Multiset; +import com.google.common.collect.Multisets; +import com.ibm.drl.hbcp.core.attributes.AttributeType; +import com.ibm.drl.hbcp.core.attributes.AttributeValuePair; +import com.ibm.drl.hbcp.core.attributes.collection.AttributeValueCollection; +import com.ibm.drl.hbcp.parser.AnnotatedAttributeValuePair; +import com.ibm.drl.hbcp.parser.Attributes; +import com.ibm.drl.hbcp.parser.JSONRefParser; +import com.ibm.drl.hbcp.util.Props; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Draw from the smoking cessation annotations to provide the most frequent intervention (as a set of BCTs) reported on +
* @author mgleize + */ +public class RecommendedInterventions { + + private RecommendedInterventions() throws IOException { + mostFrequentInterventions = getMostFrequentInterventionsInAnnotations(); + } + + // sets of BCT IDs + private final List<Set<String>> mostFrequentInterventions; + + /** Returns at most 'max' of the interventions sorted by descending frequency */ + public List<Set<String>> getTopKMostFrequentInterventions(int max) { + return mostFrequentInterventions.stream() + .limit(max) + .collect(Collectors.toList()); + } + + /** Transform a BC scenario with a set of all possible BCTs allowed, into a set of recommended scenarios with reasonable + * combinations of BCTs taken from that set. */ + public List<Set<String>> getRecommendedInterventions(List<AttributeValuePair> possibleScenarios, int max) { + // extract the BCTs from the possibleScenarios, they are what is "allowed" by the user + Set<String> allowedBctIds = possibleScenarios.stream() + .filter(avp -> avp.getAttribute().getType() == AttributeType.INTERVENTION) + .map(avp -> avp.getAttribute().getId()) + .collect(Collectors.toSet()); + // filter pre-computed recommended interventions based on this + return mostFrequentInterventions.stream() + .filter(allowedBctIds::containsAll) + .limit(max) + // no need to sort, they're already sorted + .collect(Collectors.toList()); + } + + /** Transform a BC scenario with a set of all possible BCTs allowed, into a set of recommended scenarios with reasonable + * combinations of BCTs taken from that set.
*/ + public List<List<AttributeValuePair>> getRecommendedScenarios(List<AttributeValuePair> possibleScenarios, int max) { + // get recommended compatible scenarios + List<Set<String>> recommendedInterventions = getRecommendedInterventions(possibleScenarios, max); + // for each of the recommended intervention, build a new complete scenario + List<AttributeValuePair> nonInterventionAvps = possibleScenarios.stream() + .filter(avp -> avp.getAttribute().getType() != AttributeType.INTERVENTION) + .collect(Collectors.toList()); + List<List<AttributeValuePair>> res = new ArrayList<>(); + for (Set<String> recommendedIntervention : recommendedInterventions) { + List<AttributeValuePair> recommendedScenario = new ArrayList<>(nonInterventionAvps); + recommendedScenario.addAll(recommendedIntervention.stream() + .map(bctId -> new AttributeValuePair(Attributes.get().getFromId(bctId), "1")) + .collect(Collectors.toList())); + res.add(recommendedScenario); + } + return res; + } + + private static List<Set<String>> getMostFrequentInterventionsInAnnotations() throws IOException { + // read annotations + AttributeValueCollection<AnnotatedAttributeValuePair> annotations = new JSONRefParser(Props.loadProperties()) + .getAttributeValuePairs() + // do not forget to distribute the empty arms here + .distributeEmptyArm(); + // get all interventions (through their ID sets) and their frequency + Multiset<Set<String>> res = ConcurrentHashMultiset.create(); + for (String docName : annotations.getDocNames()) { + for (Multiset<AnnotatedAttributeValuePair> armifiedAvps : annotations.getArmifiedPairsInDoc(docName).values()) { + Set<String> bctIds = armifiedAvps.stream() + .filter(avp -> avp.getAttribute().getType() == AttributeType.INTERVENTION) + .map(avp -> avp.getAttribute().getId()) + .collect(Collectors.toSet()); + if (!bctIds.isEmpty()) + res.add(bctIds); + } + } + // sort by frequency + ImmutableMultiset<Set<String>> sortedRes = Multisets.copyHighestCountFirst(res); + return new ArrayList<>(sortedRes.elementSet()); + } + + // implements the lazy-initialization thread-safe singleton pattern + private static class LazyHolder { + private static RecommendedInterventions buildInterventions() { + try { + return new
RecommendedInterventions(); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("IOException when lazy-initializing singleton RecommendedInterventions", e); + } + } + private static final RecommendedInterventions INSTANCE = buildInterventions(); + } + + /** Returns the Attributes collection for the JSON annotation file defined in the default properties */ + public static RecommendedInterventions get() { + return RecommendedInterventions.LazyHolder.INSTANCE; + } + + public static void main(String[] args) { + for (Set intervention : RecommendedInterventions.get().getTopKMostFrequentInterventions(20)) { + System.out.println(intervention); + } + } +} diff --git a/prediction-experiments/python-nb/ov-predict/src/api/predict_app_outcome_and_confidence.py b/prediction-experiments/python-nb/ov-predict/src/api/predict_app_outcome_and_confidence.py index 253d535..19a73d3 100644 --- a/prediction-experiments/python-nb/ov-predict/src/api/predict_app_outcome_and_confidence.py +++ b/prediction-experiments/python-nb/ov-predict/src/api/predict_app_outcome_and_confidence.py @@ -23,7 +23,7 @@ inpH = init_embedding(EMBFILE) -@app.route('/') +@app.route('/hbcp/api/v1.0/predict/') def home(): return "This is the prediction API." @@ -34,6 +34,36 @@ def preprocess_request(data): return pp_string + +@app.route('/hbcp/api/v1.0/predict/outcome/batch/', methods=['POST']) +def outcome_batch(): + # load models + print ("#word-nodes = {}".format(len(inpH.pre_emb))) + trained_model = init_model(inpH) # do this everytime... + trained_model_for_mc = init_model(inpH, saved_model_wts_file=MC_MODEL_WT_FILE, + num_classes=7) # load the mc model for getting confidences + # expected data is a query string per line of body, parse it + queries = request.data.decode('UTF-8').splitlines() + results = [] + for query in queries: + pp_data = preprocess_request(query) + print ("pp_data = {}".format(pp_data)) + #Reshape the vectors... 
this is not needed as we're using a merged vocab + #print ("#Reshaping the original w/o context vectors word-nodes") + #inpH.modifyW2V(pp_data) + predicted_val = predict_outcome_with_dynamic_vocabchange(inpH, trained_model, pp_data, NODEVEC_DIM) + confidence = predict_confidence(inpH, trained_model_for_mc, pp_data, NODEVEC_DIM) + result = {} + result['value'] = str(predicted_val) + result['conf'] = str(confidence) + results.append(result) + # clear Keras session + k.clear_session() + # prepare top level JSON object + retmap = {'results': results} + return json.dumps(retmap) + + @app.route('/hbcp/api/v1.0/predict/outcome/', methods=['GET']) def outcome(data): print ("#word-nodes = {}".format(len(inpH.pre_emb))) @@ -59,6 +89,7 @@ def outcome(data): return json.dumps(retmap) + if __name__ == '__main__': try: unparsed_port = os.environ['PORT']