Skip to content

Commit

Permalink
Get schema from gremlin/sparql query results for Neptune qpt and refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
VenkatasivareddyTR committed Jan 22, 2025
1 parent c28f9d6 commit a53d82f
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import com.amazonaws.athena.connector.lambda.QueryStatusChecker;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockWriter;
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
Expand All @@ -47,9 +45,7 @@
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal;
Expand All @@ -64,7 +60,6 @@
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -282,6 +277,7 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
queryPassthrough.verify(qptArguments);
String schemaName = qptArguments.get(NeptuneQueryPassthrough.DATABASE);
String tableName = qptArguments.get(NeptuneQueryPassthrough.COLLECTION);
String componentTypeValue = qptArguments.get(NeptuneQueryPassthrough.COMPONENT_TYPE);
TableName tableNameObj = new TableName(schemaName, tableName);
Schema schema;
Enums.GraphType graphType = Enums.GraphType.PROPERTYGRAPH;
Expand All @@ -304,7 +300,7 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
if (responseObj instanceof Map && gremlinQuery.contains(Constants.GREMLIN_QUERY_SUPPORT_TYPE)) {
logger.info("NeptuneMetadataHandler doGetQueryPassthroughSchema gremlinQuery with valueMap");
Map graphTraversalObj = (Map) responseObj;
schema = getSchemaFromResults(graphTraversalObj);
schema = NeptuneSchemaUtils.getSchemaFromResults(graphTraversalObj, componentTypeValue, tableName);
return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
}
else {
Expand All @@ -324,8 +320,8 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
NeptuneSparqlConnection neptuneSparqlConnection = (NeptuneSparqlConnection) neptuneConnection;
neptuneSparqlConnection.runQuery(sparqlQuery);
if (neptuneSparqlConnection.hasNext()) {
Map<String, Object> resultsMap = neptuneSparqlConnection.next(true);
schema = getSchemaFromResults(resultsMap);
Map<String, Object> resultsMap = neptuneSparqlConnection.next();
schema = NeptuneSchemaUtils.getSchemaFromResults(resultsMap, componentTypeValue, tableName);
return new GetTableResponse(request.getCatalogName(), tableNameObj, schema);
}
else {
Expand All @@ -336,72 +332,4 @@ public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, Ge
throw new IllegalArgumentException("Unsupported graphType: " + graphType);
}
}

private Schema getSchemaFromResults(Map resultsMap)
{
Schema schema;
SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
//Building schema from gremlin/sparql query results.
resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), columnValue, schemaBuilder));
schema = schemaBuilder.build();

return schema;
}

private void buildSchema(String columnName, Object columnValue, SchemaBuilder schemaBuilder)
{
schemaBuilder.addField(getArrowFieldForNeptune(columnName, columnValue));
}

/**
* Infers the type of a field from Neptune data.
*
* @param key The key of the field we are attempting to infer.
* @param value A value from the key whose type we are attempting to infer.
* @return The Apache Arrow field definition of the inferred key/value.
*/
private Field getArrowFieldForNeptune(String key, Object value)
{
if (value instanceof String) {
return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
}
else if (value instanceof Integer) {
return new Field(key, FieldType.nullable(Types.MinorType.INT.getType()), null);
}
else if (value instanceof Long) {
return new Field(key, FieldType.nullable(Types.MinorType.BIGINT.getType()), null);
}
else if (value instanceof Boolean) {
return new Field(key, FieldType.nullable(Types.MinorType.BIT.getType()), null);
}
else if (value instanceof Float) {
return new Field(key, FieldType.nullable(Types.MinorType.FLOAT4.getType()), null);
}
else if (value instanceof Double) {
return new Field(key, FieldType.nullable(Types.MinorType.FLOAT8.getType()), null);
}
else if (value instanceof java.util.Date) {
return new Field(key, FieldType.nullable(Types.MinorType.DATEMILLI.getType()), null);
}
else if (value instanceof java.util.UUID) {
return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
}
else if (value instanceof List) {
Field child;
if (((List<?>) value).isEmpty()) {
logger.warn("getArrowFieldForNeptune: Encountered an empty List for field[{}], defaulting to List<String> due to type erasure.", key);
return FieldBuilder.newBuilder(key, Types.MinorType.LIST.getType()).addStringField("").build();
}
else {
child = getArrowFieldForNeptune("", ((List<?>) value).get(0));
}
return new Field(key, FieldType.nullable(Types.MinorType.LIST.getType()),
Collections.singletonList(child));
}

String className = (value == null || value.getClass() == null) ? "null" : value.getClass().getName();
logger.warn("Unknown type[{}] for field[{}], defaulting to varchar.", className, key);
return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*-
* #%L
* athena-neptune
* %%
* Copyright (C) 2019 - 2020 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.neptune;

import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigInteger;
import java.util.List;
import java.util.Map;

/**
* Collection of helpful utilities that handle Neptune schema inference, type, and naming conversion.
*/
public class NeptuneSchemaUtils
{
    private static final Logger logger = LoggerFactory.getLogger(NeptuneSchemaUtils.class);

    private NeptuneSchemaUtils() {}

    /**
     * Builds an Arrow {@link Schema} by inferring a field type for each entry of a
     * gremlin/sparql query result row, and attaches the component type and table name
     * (graph label) as schema-level metadata.
     *
     * @param resultsMap         A single result row: column name -> sample value.
     * @param componentTypeValue Neptune component type (e.g. vertex/edge), stored under
     *                           {@code Constants.SCHEMA_COMPONENT_TYPE}.
     * @param tableName          Table name, stored under {@code Constants.SCHEMA_GLABEL}.
     * @return The inferred schema; one field per map entry, in the map's iteration order.
     */
    public static Schema getSchemaFromResults(Map<?, ?> resultsMap, String componentTypeValue, String tableName)
    {
        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
        // Building schema from gremlin/sparql query results.
        resultsMap.forEach((columnName, columnValue) -> buildSchema(columnName.toString(), columnValue, schemaBuilder));
        schemaBuilder.addMetadata(Constants.SCHEMA_COMPONENT_TYPE, componentTypeValue);
        schemaBuilder.addMetadata(Constants.SCHEMA_GLABEL, tableName);
        return schemaBuilder.build();
    }

    // Adds a single inferred field to the schema under construction.
    private static void buildSchema(String columnName, Object columnValue, SchemaBuilder schemaBuilder)
    {
        schemaBuilder.addField(getArrowFieldForNeptune(columnName, columnValue));
    }

    /**
     * Infers the type of a field from Neptune data.
     *
     * <p>Lists are flattened to the type of their first element; an empty list (or an
     * unrecognized/null value) falls back to VARCHAR with a warning, since the element
     * type cannot be determined.
     *
     * @param key The key of the field we are attempting to infer.
     * @param value A value from the key whose type we are attempting to infer.
     * @return The Apache Arrow field definition of the inferred key/value.
     */
    private static Field getArrowFieldForNeptune(String key, Object value)
    {
        if (value instanceof String || value instanceof java.util.UUID) {
            return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
        }
        else if (value instanceof Integer) {
            return new Field(key, FieldType.nullable(Types.MinorType.INT.getType()), null);
        }
        else if (value instanceof BigInteger) {
            // NOTE(review): BigInteger may exceed the 64-bit BIGINT range — values wider
            // than a long would be truncated downstream; confirm acceptable for Neptune data.
            return new Field(key, FieldType.nullable(Types.MinorType.BIGINT.getType()), null);
        }
        else if (value instanceof Long) {
            return new Field(key, FieldType.nullable(Types.MinorType.BIGINT.getType()), null);
        }
        else if (value instanceof Boolean) {
            return new Field(key, FieldType.nullable(Types.MinorType.BIT.getType()), null);
        }
        else if (value instanceof Float) {
            return new Field(key, FieldType.nullable(Types.MinorType.FLOAT4.getType()), null);
        }
        else if (value instanceof Double) {
            return new Field(key, FieldType.nullable(Types.MinorType.FLOAT8.getType()), null);
        }
        else if (value instanceof java.util.Date) {
            return new Field(key, FieldType.nullable(Types.MinorType.DATEMILLI.getType()), null);
        }
        else if (value instanceof List && !((List<?>) value).isEmpty()) {
            // Infer from the first element; multi-valued properties are flattened.
            return getArrowFieldForNeptune(key, ((List<?>) value).get(0));
        }

        // Covers null values, empty lists, and any type not handled above.
        String className = (value == null) ? "null" : value.getClass().getName();
        logger.warn("Unknown type[{}] for field[{}], defaulting to varchar.", className, key);
        return new Field(key, FieldType.nullable(Types.MinorType.VARCHAR.getType()), null);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public final class NeptuneQueryPassthrough implements QueryPassthroughSignature
public static final String DATABASE = "DATABASE";
public static final String COLLECTION = "COLLECTION";
public static final String QUERY = "QUERY";
public static final String COMPONENT_TYPE = "COMPONENTTYPE";

@Override
public String getFunctionSchema()
Expand All @@ -56,7 +57,7 @@ public String getFunctionName()
@Override
public List<String> getFunctionArguments()
{
return Arrays.asList(DATABASE, COLLECTION, QUERY);
return Arrays.asList(DATABASE, COLLECTION, QUERY, COMPONENT_TYPE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public boolean hasNext()
return false;
}

public Map<String, Object> next(boolean trimURI)
public Map<String, Object> next()
{
Map<String, Object> ret = new HashMap<String, Object>();
BindingSet bindingSet = this.queryResult.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,9 @@ else if (queryMode.equals(Constants.QUERY_MODE_CLASS)) {
final GeneratedRowWriter rowWriter = builder.build();

// get results
String strim = recordsRequest.getSchema().getCustomMetadata().get(Constants.SCHEMA_STRIP_URI);
boolean trimURI = strim == null ? false : Boolean.parseBoolean(strim);
neptuneConnection.runQuery(sparql.toString());
while (neptuneConnection.hasNext() && queryStatusChecker.isQueryRunning()) {
Map<String, Object> result = neptuneConnection.next(trimURI);
Map<String, Object> result = neptuneConnection.next();
spiller.writeRows((final Block block, final int rowNum) -> {
return (rowWriter.writeRow(block, rowNum, (Object) result) ? 1 : 0);
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*-
* #%L
* athena-neptune
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.neptune;

import com.google.common.collect.ImmutableMap;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Map;

import static org.junit.Assert.assertEquals;

@RunWith(MockitoJUnitRunner.class)
public class NeptuneSchemaUtilsTest {
    private static final Logger logger = LoggerFactory.getLogger(NeptuneSchemaUtilsTest.class);

    // One sample value per supported inference branch (string, int, double, boolean, BigInteger).
    private final Map<String, Object> objectMap = ImmutableMap.of(
            "col1", "String",
            "col2", 1,
            "col3", 10.33,
            "col4", true,
            "col5", new BigInteger("12345678901234567890"));

    private static final String COMPONENT_TYPE = "vertex";
    private static final String TABLE_NAME = "test";

    @Test
    public void getSchemaFromResults() {
        logger.info("getSchemaFromResults - enter");
        Schema schema = NeptuneSchemaUtils.getSchemaFromResults(objectMap, COMPONENT_TYPE, TABLE_NAME);

        // JUnit assertEquals takes (expected, actual).
        assertEquals(objectMap.size(), schema.getFields().size());
        assertEquals("Utf8", schema.findField("col1").getType().toString());
        assertEquals("Int(32, true)", schema.findField("col2").getType().toString());
        assertEquals("FloatingPoint(DOUBLE)", schema.findField("col3").getType().toString());
        assertEquals("Bool", schema.findField("col4").getType().toString());
        assertEquals("Int(64, true)", schema.findField("col5").getType().toString());

        // Both metadata entries written by getSchemaFromResults must round-trip.
        assertEquals(COMPONENT_TYPE, schema.getCustomMetadata().get(Constants.SCHEMA_COMPONENT_TYPE));
        assertEquals(TABLE_NAME, schema.getCustomMetadata().get(Constants.SCHEMA_GLABEL));
        logger.info("getSchemaFromResults - exit");
    }
}

0 comments on commit a53d82f

Please sign in to comment.