From 3e6393d185ef61828b05eb61fd7868c78ebec7d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joaqu=C3=ADn=20T=C3=A1rraga=20Gim=C3=A9nez?= Date: Fri, 9 Jun 2017 14:30:24 +0100 Subject: [PATCH] analysis: compute OR from logistic regression coefficient, #126 --- .../analysis/variant/LogisticRegressionAnalysis.java | 5 +++-- .../hpg/bigdata/app/cli/local/VariantAssocCLITest.java | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/hpg-bigdata-analysis/src/main/java/org/opencb/hpg/bigdata/analysis/variant/LogisticRegressionAnalysis.java b/hpg-bigdata-analysis/src/main/java/org/opencb/hpg/bigdata/analysis/variant/LogisticRegressionAnalysis.java index 5875528a..258ed939 100644 --- a/hpg-bigdata-analysis/src/main/java/org/opencb/hpg/bigdata/analysis/variant/LogisticRegressionAnalysis.java +++ b/hpg-bigdata-analysis/src/main/java/org/opencb/hpg/bigdata/analysis/variant/LogisticRegressionAnalysis.java @@ -1,5 +1,6 @@ package org.opencb.hpg.bigdata.analysis.variant; +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; @@ -7,6 +8,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.functions; /** * Created by jtarraga on 30/05/17. @@ -26,13 +28,13 @@ public void execute() { // print the coefficients and intercept for linear regression System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + System.out.println("OR (for coeff.) = " + Math.exp(lrModel.coefficients().apply(1))); // summarize the model over the training set and print out some metrics LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); System.out.println("numIterations: " + trainingSummary.totalIterations()); System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); -/* // obtain the loss per iteration double[] objectiveHistory = trainingSummary.objectiveHistory(); for (double lossPerIteration : objectiveHistory) { @@ -58,7 +60,6 @@ public void execute() { double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) .select("threshold").head().getDouble(0); lrModel.setThreshold(bestThreshold); -*/ } public LogisticRegressionAnalysis(String datasetName, String studyName, String depVarName, String indepVarName, diff --git a/hpg-bigdata-app/src/test/java/org/opencb/hpg/bigdata/app/cli/local/VariantAssocCLITest.java b/hpg-bigdata-app/src/test/java/org/opencb/hpg/bigdata/app/cli/local/VariantAssocCLITest.java index 71c53c11..ad5f21db 100644 --- a/hpg-bigdata-app/src/test/java/org/opencb/hpg/bigdata/app/cli/local/VariantAssocCLITest.java +++ b/hpg-bigdata-app/src/test/java/org/opencb/hpg/bigdata/app/cli/local/VariantAssocCLITest.java @@ -31,9 +31,9 @@ public void assoc() { commandLine.append(" -i ").append(inPath); commandLine.append(" -o ").append(outPath); commandLine.append(" --dataset noname"); - //commandLine.append(" --logistic"); - commandLine.append(" --linear"); - commandLine.append(" --pheno Age:s"); + commandLine.append(" --logistic"); + //commandLine.append(" --linear"); + //commandLine.append(" --pheno Age:s"); VariantQueryCLITest.execute(commandLine.toString()); } catch (Exception e) {