Skip to content

Commit

Permalink
Updated SigV4 signing library in Gremlin connection for Neptune conne…
Browse files Browse the repository at this point in the history
…ctor (#1698)
  • Loading branch information
xiazcy authored May 9, 2024
1 parent b81300d commit e080a7c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 28 deletions.
3 changes: 2 additions & 1 deletion athena-neptune/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
<artifactId>athena-neptune</artifactId>
<version>2022.47.1</version>
<properties>
<gremlinDriverVersion>3.7.2</gremlinDriverVersion>
<!-- make sure gremlin driver version stays within the Neptune supported range -->
<gremlinDriverVersion>3.6.5</gremlinDriverVersion>
<neptune.sigv4.signer.version>2.4.0</neptune.sigv4.signer.version>
</properties>
<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@

import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection;
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Cluster;
import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer;
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneConnection
{
private static Cluster cluster = null;

private static final Logger logger = LoggerFactory.getLogger(NeptuneConnection.class);

private String neptuneEndpoint;
private String neptunePort;
private boolean enabledIAM;
Expand All @@ -45,7 +51,22 @@ protected NeptuneConnection(String neptuneEndpoint, String neptunePort, boolean
.enableSsl(true);

if (enabledIAM) {
builder = builder.channelizer(SigV4WebSocketChannelizer.class);
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
}
);
}

cluster = builder.create();
Expand Down Expand Up @@ -77,7 +98,7 @@ public static NeptuneConnection createConnection(java.util.Map<String, String> c
throw new IllegalArgumentException("Unsupported graphType: " + graphType);
}
}

public String getNeptuneEndpoint()
{
return this.neptuneEndpoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@
package com.amazonaws.athena.connectors.neptune.propertygraph;

import com.amazonaws.athena.connectors.neptune.NeptuneConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Cluster;
import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer;
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneGremlinConnection extends NeptuneConnection
{
private static final Logger logger = LoggerFactory.getLogger(NeptuneGremlinConnection.class);
private static Cluster cluster = null;

public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, boolean enabledIAM, String region)
Expand All @@ -40,7 +46,22 @@ public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, bool
.enableSsl(true);

if (enabledIAM) {
builder = builder.channelizer(SigV4WebSocketChannelizer.class);
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
}
);
}

cluster = builder.create();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.tinkerpop.gremlin.structure.T;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Date;
Expand All @@ -50,19 +52,22 @@
* This class is a Utility class to create Extractors for each field type as per
* Schema
*/
public final class CustomSchemaRowWriter
public final class CustomSchemaRowWriter
{
private CustomSchemaRowWriter()
private static final Logger logger = LoggerFactory.getLogger(CustomSchemaRowWriter.class);
private CustomSchemaRowWriter()
{
// Empty private constructor
}

public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map<String, String> configOptions)
public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map<String, String> configOptions)
{
ArrowType arrowType = field.getType();
Types.MinorType minorType = Types.getMinorTypeForArrowType(arrowType);
logger.debug("writeRowTemplate*" + field.getName() + "*" + minorType + "*");
Boolean enableCaseinsensitivematch = (configOptions.get(Constants.SCHEMA_CASE_INSEN) == null) ? true : Boolean.parseBoolean(configOptions.get(Constants.SCHEMA_CASE_INSEN));

try {
switch (minorType) {
case BIT:
rowWriterBuilder.withExtractor(field.getName(),
Expand All @@ -72,19 +77,22 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Boolean.class)) {
logger.debug("writeRowTemplate BIT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue.getClass().equals(Boolean.class)) {
Boolean booleanValue = Boolean.parseBoolean(fieldValue.toString());
value.value = booleanValue ? 1 : 0;
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) obj.get(field.getName());
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
Boolean booleanValue = Boolean.parseBoolean(objValues.get(0).toString());
value.value = booleanValue ? 1 : 0;
value.isSet = 1;
}
}
}
});
break;

Expand All @@ -102,23 +110,29 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.value = fieldValue.toString();
value.isSet = 1;
}
}
}
else {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate VARCHAR*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue != null) {
if (fieldValue.getClass().equals(String.class)) {
value.value = fieldValue.toString();
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null) {
value.value = objValues.get(0).toString();
value.isSet = 1;
}
}
}
else {
value.value = "" + fieldValue;
value.isSet = 1;
}
}
}
});
break;
Expand All @@ -131,11 +145,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Date.class)) {
logger.debug("writeRowTemplate DATEMILLI*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Date.class)) {
value.value = ((Date) fieldValue).getTime();
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && (objValues.get(0) != null) && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = ((Date) objValues.get(0)).getTime();
Expand All @@ -153,11 +169,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Integer.class)) {
logger.debug("writeRowTemplate INT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Integer.class)) {
value.value = Integer.parseInt(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Integer.parseInt(objValues.get(0).toString());
Expand All @@ -175,11 +193,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Long.class)) {
logger.debug("writeRowTemplate BIGINT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Long.class)) {
value.value = Long.parseLong(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Long.parseLong(objValues.get(0).toString());
Expand All @@ -197,11 +217,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Float.class)) {
logger.debug("writeRowTemplate FLOAT4*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Float.class)) {
value.value = Float.parseFloat(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Float.parseFloat(objValues.get(0).toString());
Expand All @@ -218,12 +240,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
Map<String, Object> obj = (Map<String, Object>) contextAsMap(context, enableCaseinsensitivematch);
value.isSet = 0;

Object fieldValue = obj.get(field.getName());
if (fieldValue.getClass().equals(Double.class)) {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate FLOAT8*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Double.class)) {
value.value = Double.parseDouble(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Double.parseDouble(objValues.get(0).toString());
Expand All @@ -234,9 +258,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie

break;
}
}
catch (Throwable e) {
logger.error("writeRowTemplate exception for *" + field.getName() + "*" + minorType + "*", e);
throw new RuntimeException(e);
}
}

private static Map<String, Object> contextAsMap(Object context, boolean caseInsensitive)
private static Map<String, Object> contextAsMap(Object context, boolean caseInsensitive)
{
Map<String, Object> contextAsMap = (Map<String, Object>) context;
Object fieldValueID = contextAsMap.get(T.id);
Expand Down

0 comments on commit e080a7c

Please sign in to comment.