From e080a7c55a34de6a5726badda24560cd441e58ee Mon Sep 17 00:00:00 2001 From: Yang Xia <55853655+xiazcy@users.noreply.github.com> Date: Wed, 8 May 2024 18:44:09 -0700 Subject: [PATCH] Updated SigV4 signing library in Gremlin connection for Neptune connector (#1698) --- athena-neptune/pom.xml | 3 +- .../connectors/neptune/NeptuneConnection.java | 29 ++++++-- .../NeptuneGremlinConnection.java | 25 ++++++- .../rowwriters/CustomSchemaRowWriter.java | 71 +++++++++++++------ 4 files changed, 100 insertions(+), 28 deletions(-) diff --git a/athena-neptune/pom.xml b/athena-neptune/pom.xml index 5e5bfda028..73f1657cb8 100644 --- a/athena-neptune/pom.xml +++ b/athena-neptune/pom.xml @@ -9,7 +9,8 @@ athena-neptune 2022.47.1 - 3.7.2 + + 3.6.5 2.4.0 diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneConnection.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneConnection.java index 18a55f5bc0..bb6368016f 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneConnection.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneConnection.java @@ -21,17 +21,23 @@ import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection; import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer; +import com.amazonaws.neptune.auth.NeptuneSigV4SignerException; import org.apache.tinkerpop.gremlin.driver.Client; import org.apache.tinkerpop.gremlin.driver.Cluster; -import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer; import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class NeptuneConnection { private static Cluster cluster = null; - + private static final Logger logger = LoggerFactory.getLogger(NeptuneConnection.class); + private String neptuneEndpoint; private String neptunePort; private boolean enabledIAM; @@ -45,7 +51,22 @@ protected NeptuneConnection(String neptuneEndpoint, String neptunePort, boolean .enableSsl(true); if (enabledIAM) { - builder = builder.channelizer(SigV4WebSocketChannelizer.class); + logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region); + final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain(); + builder.handshakeInterceptor(r -> + { + try { + NeptuneNettyHttpSigV4Signer sigV4Signer = + new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider); + sigV4Signer.signRequest(r); + } + catch (NeptuneSigV4SignerException e) { + logger.error("SIGV4 exception", e); + throw new RuntimeException("Exception occurred while signing the request", e); + } + return r; + } + ); } cluster = builder.create(); @@ -77,7 +98,7 @@ public static NeptuneConnection createConnection(java.util.Map c throw new IllegalArgumentException("Unsupported graphType: " + graphType); } } - + public String getNeptuneEndpoint() { return this.neptuneEndpoint; diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/NeptuneGremlinConnection.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/NeptuneGremlinConnection.java index 6ce130defd..55b77a3d9e 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/NeptuneGremlinConnection.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/NeptuneGremlinConnection.java @@ -20,15 +20,21 @@ package com.amazonaws.athena.connectors.neptune.propertygraph; import com.amazonaws.athena.connectors.neptune.NeptuneConnection; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer; +import com.amazonaws.neptune.auth.NeptuneSigV4SignerException; import org.apache.tinkerpop.gremlin.driver.Client; import org.apache.tinkerpop.gremlin.driver.Cluster; -import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer; import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class NeptuneGremlinConnection extends NeptuneConnection { + private static final Logger logger = LoggerFactory.getLogger(NeptuneGremlinConnection.class); private static Cluster cluster = null; public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, boolean enabledIAM, String region) @@ -40,7 +46,22 @@ public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, bool .enableSsl(true); if (enabledIAM) { - builder = builder.channelizer(SigV4WebSocketChannelizer.class); + logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region); + final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain(); + builder.handshakeInterceptor(r -> + { + try { + NeptuneNettyHttpSigV4Signer sigV4Signer = + new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider); + sigV4Signer.signRequest(r); + } + catch (NeptuneSigV4SignerException e) { + logger.error("SIGV4 exception", e); + throw new RuntimeException("Exception occurred while signing the request", e); + } + return r; + } + ); } cluster = builder.create(); diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/rowwriters/CustomSchemaRowWriter.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/rowwriters/CustomSchemaRowWriter.java index ed0004ecc0..d3093fafce 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/rowwriters/CustomSchemaRowWriter.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/propertygraph/rowwriters/CustomSchemaRowWriter.java @@ -40,6 +40,8 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.tinkerpop.gremlin.structure.T; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Date; @@ -50,19 +52,22 @@ * This class is a Utility class to create Extractors for each field type as per * Schema */ -public final class CustomSchemaRowWriter +public final class CustomSchemaRowWriter { - private CustomSchemaRowWriter() + private static final Logger logger = LoggerFactory.getLogger(CustomSchemaRowWriter.class); + private CustomSchemaRowWriter() { // Empty private constructor } - public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map configOptions) + public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map configOptions) { ArrowType arrowType = field.getType(); Types.MinorType minorType = Types.getMinorTypeForArrowType(arrowType); + logger.debug("writeRowTemplate*" + field.getName() + "*" + minorType + "*"); Boolean enableCaseinsensitivematch = (configOptions.get(Constants.SCHEMA_CASE_INSEN) == null) ? true : Boolean.parseBoolean(configOptions.get(Constants.SCHEMA_CASE_INSEN)); + try { switch (minorType) { case BIT: rowWriterBuilder.withExtractor(field.getName(), @@ -72,19 +77,22 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.isSet = 0; Object fieldValue = obj.get(fieldName); - if (fieldValue.getClass().equals(Boolean.class)) { + logger.debug("writeRowTemplate BIT*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + + if (fieldValue.getClass().equals(Boolean.class)) { Boolean booleanValue = Boolean.parseBoolean(fieldValue.toString()); value.value = booleanValue ? 1 : 0; value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) obj.get(field.getName()); if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) { Boolean booleanValue = Boolean.parseBoolean(objValues.get(0).toString()); value.value = booleanValue ? 1 : 0; value.isSet = 1; } - } + } }); break; @@ -102,23 +110,29 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.value = fieldValue.toString(); value.isSet = 1; } - } + } else { Object fieldValue = obj.get(fieldName); + logger.debug("writeRowTemplate VARCHAR*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); if (fieldValue != null) { if (fieldValue.getClass().equals(String.class)) { value.value = fieldValue.toString(); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && objValues.get(0) != null) { value.value = objValues.get(0).toString(); value.isSet = 1; } } - } + else { + value.value = "" + fieldValue; + value.isSet = 1; + } + } } }); break; @@ -131,11 +145,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.isSet = 0; Object fieldValue = obj.get(fieldName); - if (fieldValue.getClass().equals(Date.class)) { + logger.debug("writeRowTemplate DATEMILLI*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + if (fieldValue.getClass().equals(Date.class)) { value.value = ((Date) fieldValue).getTime(); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && (objValues.get(0) != null) && !(objValues.get(0).toString().trim().isEmpty())) { value.value = ((Date) objValues.get(0)).getTime(); @@ -153,11 +169,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.isSet = 0; Object fieldValue = obj.get(fieldName); - if (fieldValue.getClass().equals(Integer.class)) { + logger.debug("writeRowTemplate INT*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + if (fieldValue.getClass().equals(Integer.class)) { value.value = Integer.parseInt(fieldValue.toString()); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) { value.value = Integer.parseInt(objValues.get(0).toString()); @@ -175,11 +193,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.isSet = 0; Object fieldValue = obj.get(fieldName); - if (fieldValue.getClass().equals(Long.class)) { + logger.debug("writeRowTemplate BIGINT*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + if (fieldValue.getClass().equals(Long.class)) { value.value = Long.parseLong(fieldValue.toString()); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) { value.value = Long.parseLong(objValues.get(0).toString()); @@ -197,11 +217,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie value.isSet = 0; Object fieldValue = obj.get(fieldName); - if (fieldValue.getClass().equals(Float.class)) { + logger.debug("writeRowTemplate FLOAT4*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + if (fieldValue.getClass().equals(Float.class)) { value.value = Float.parseFloat(fieldValue.toString()); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) { value.value = Float.parseFloat(objValues.get(0).toString()); @@ -218,12 +240,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie Map obj = (Map) contextAsMap(context, enableCaseinsensitivematch); value.isSet = 0; - Object fieldValue = obj.get(field.getName()); - if (fieldValue.getClass().equals(Double.class)) { + Object fieldValue = obj.get(fieldName); + logger.debug("writeRowTemplate FLOAT8*" + field.getName() + "*" + minorType + "*" + + (fieldValue == null ? "" : fieldValue.getClass()) + "*"); + if (fieldValue.getClass().equals(Double.class)) { value.value = Double.parseDouble(fieldValue.toString()); value.isSet = 1; } - else { + else if (fieldValue instanceof ArrayList) { ArrayList objValues = (ArrayList) fieldValue; if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) { value.value = Double.parseDouble(objValues.get(0).toString()); @@ -234,9 +258,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie break; } + } + catch (Throwable e) { + logger.error("writeRowTemplate exception for *" + field.getName() + "*" + minorType + "*", e); + throw new RuntimeException(e); + } } - private static Map contextAsMap(Object context, boolean caseInsensitive) + private static Map contextAsMap(Object context, boolean caseInsensitive) { Map contextAsMap = (Map) context; Object fieldValueID = contextAsMap.get(T.id);