Skip to content

Commit

Permalink
Added QPT Support for TPCDS Connector (#2168)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdulR3hman authored Aug 13, 2024
1 parent 9601944 commit 5a56a56
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
Expand All @@ -35,9 +38,12 @@
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.tpcds.qpt.TPCDSQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.teradata.tpcds.Table;
import com.teradata.tpcds.column.Column;
Expand All @@ -48,6 +54,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -74,6 +81,8 @@ public class TPCDSMetadataHandler
protected static final String SPLIT_SCALE_FACTOR_FIELD = "scaleFactor";
//The list of valid schemas which also convey the scale factor
protected static final Set<String> SCHEMA_NAMES = ImmutableSet.of("tpcds1", "tpcds10", "tpcds100", "tpcds250", "tpcds1000");
// Query Passthrough
protected static final TPCDSQueryPassthrough queryPassthrough = new TPCDSQueryPassthrough();

/**
* used to aid in debugging. Athena will use this name in conjunction with your catalog id
Expand All @@ -98,6 +107,15 @@ protected TPCDSMetadataHandler(
super(keyFactory, secretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions);
}

/**
 * Reports the optimization capabilities this connector supports. The only capability
 * advertised here is query passthrough, and it is added only when enabled via configOptions.
 *
 * @param allocator Tool for creating and managing Apache Arrow Blocks.
 * @param request The capabilities request, which includes the catalog name.
 * @return A response pairing the request's catalog name with the advertised capabilities.
 */
@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
    ImmutableMap.Builder<String, List<OptimizationSubType>> capabilityBuilder = ImmutableMap.builder();
    queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilityBuilder, configOptions);
    return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilityBuilder.build());
}

/**
* Returns our static list of schemas which correspond to the scale factor of the dataset we will generate.
*
Expand Down Expand Up @@ -151,6 +169,24 @@ public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest req
Collections.EMPTY_SET);
}

/**
 * Builds the schema for a Query Pass-Through request by resolving the TPCDS table named
 * in the passthrough arguments and converting each of its columns to an Arrow field.
 *
 * @param allocator Tool for creating and managing Apache Arrow Blocks.
 * @param request The request, whose passthrough arguments name the TPCDS table.
 * @return A GetTableResponse with the generated schema and no partition columns.
 * @throws Exception if the named table is not part of the TPCDS generated schema.
 */
@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
    // Use SLF4J parameterized logging so the request is only formatted when INFO is enabled.
    logger.info("doGetQueryPassthroughSchema: enter - {}", request);
    logger.warn("This method is only for testing purpose and should not be used in production");
    Table table = TPCDSUtils.validateQptTable(request.getQueryPassthroughArguments());

    SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
    for (Column nextCol : table.getColumns()) {
        schemaBuilder.addField(TPCDSUtils.convertColumn(nextCol));
    }

    return new GetTableResponse(request.getCatalogName(),
            request.getTableName(),
            schemaBuilder.build(),
            // Typed empty set instead of the raw-typed Collections.EMPTY_SET.
            Collections.emptySet());
}

/**
* We do not support partitioning at this time since Partition Pruning Performance is not part of the dimensions
* we test using TPCDS. By making this a NoOp the Athena Federation SDK will automatically generate a single
Expand All @@ -175,12 +211,28 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
@Override
public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
{
String catalogName = request.getCatalogName();
int scaleFactor = TPCDSUtils.extractScaleFactor(request.getTableName().getSchemaName());
logger.info("{}: Catalog {}, table {}", request.getQueryId(), request.getTableName().getSchemaName(), request.getTableName().getTableName());
String catalogName;
String schemaName;
String tableName;
if (request.getConstraints().isQueryPassThrough()) {
logger.info("QPT Split Requested");
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
catalogName = qptArguments.get(TPCDSQueryPassthrough.TPCDS_CATALOG);
schemaName = qptArguments.get(TPCDSQueryPassthrough.TPCDS_SCHEMA);
tableName = qptArguments.get(TPCDSQueryPassthrough.TPCDS_TABLE);
}
else {
catalogName = request.getCatalogName();
schemaName = request.getTableName().getSchemaName();
tableName = request.getTableName().getTableName();
}

int scaleFactor = TPCDSUtils.extractScaleFactor(schemaName);
int totalSplits = (int) Math.ceil(((double) scaleFactor / 48D)); //each split would be ~48MB

logger.info("doGetSplits: Generating {} splits for {} at scale factor {}",
totalSplits, request.getTableName(), scaleFactor);
totalSplits, tableName, scaleFactor);

int nextSplit = request.getContinuationToken() == null ? 0 : Integer.parseInt(request.getContinuationToken());
Set<Split> splits = new HashSet<>();
Expand All @@ -198,4 +250,21 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
logger.info("doGetSplits: exit - " + splits.size());
return new GetSplitsResponse(catalogName, splits);
}

/**
 * Builds the single fixed split used for Query Pass-Through requests. The passthrough
 * arguments are copied onto the split as properties so the record handler can replay them.
 *
 * @param request The GetSplitsRequest carrying the query passthrough constraints.
 * @return A GetSplitsResponse containing exactly one split.
 */
protected GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
{
    //QPT always produces exactly one split, carrying the passthrough arguments as properties.
    Map<String, String> passthroughArgs = request.getConstraints().getQueryPassthroughArguments();

    //Each split needs its own spill location so spilling cannot collide across splits.
    SpillLocation location = makeSpillLocation(request);

    Split split = Split.newBuilder(location, makeEncryptionKey())
            .applyProperties(passthroughArgs)
            .build();
    return new GetSplitsResponse(request.getCatalogName(), split);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockSpiller;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.services.athena.AmazonAthena;
Expand All @@ -50,7 +49,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.amazonaws.athena.connectors.tpcds.TPCDSMetadataHandler.SPLIT_NUMBER_FIELD;
import static com.amazonaws.athena.connectors.tpcds.TPCDSMetadataHandler.SPLIT_SCALE_FACTOR_FIELD;
Expand Down Expand Up @@ -100,7 +98,14 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor
int splitNumber = Integer.parseInt(split.getProperty(SPLIT_NUMBER_FIELD));
int totalNumSplits = Integer.parseInt(split.getProperty(SPLIT_TOTAL_NUMBER_FIELD));
int scaleFactor = Integer.parseInt(split.getProperty(SPLIT_SCALE_FACTOR_FIELD));
Table table = validateTable(recordsRequest.getTableName());

Table table;
if (recordsRequest.getConstraints().isQueryPassThrough()) {
table = TPCDSUtils.validateQptTable(recordsRequest.getConstraints().getQueryPassthroughArguments());
}
else {
table = TPCDSUtils.validateTable(recordsRequest.getTableName());
}

Session session = Session.getDefaultSession()
.withScale(scaleFactor)
Expand All @@ -125,25 +130,6 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor
}
}

/**
 * Requires that the requested Table be present in the TPCDS generated schema.
 *
 * @param tableName The fully qualified name of the requested table.
 * @return The TPCDS table, if present, otherwise the method throws.
 */
private Table validateTable(TableName tableName)
{
    String requested = tableName.getTableName();
    return Table.getBaseTables().stream()
            .filter(candidate -> candidate.getName().equals(requested))
            .findFirst()
            .orElseThrow(() -> new RuntimeException("Unknown table " + tableName));
}

/**
* Generates the CellWriters used to convert the TPCDS Generators data to Apache Arrow.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connectors.tpcds.qpt.TPCDSQueryPassthrough;
import com.teradata.tpcds.Table;
import com.teradata.tpcds.column.Column;
import com.teradata.tpcds.column.ColumnType;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;

import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -103,4 +105,27 @@ public static Table validateTable(TableName tableName)

return table.get();
}

/**
 * Resolves the TPCDS table named in the Query Passthrough arguments. Query Passthrough
 * supports only ONE table per select statement and should only be used for testing.
 *
 * @param query The Query Passthrough arguments; the table name is read under
 *              {@link TPCDSQueryPassthrough#TPCDS_TABLE}.
 * @return The TPCDS table, if present, otherwise the method throws.
 */
public static Table validateQptTable(Map<String, String> query)
{
    String requestedTable = query.get(TPCDSQueryPassthrough.TPCDS_TABLE);

    return Table.getBaseTables().stream()
            .filter(candidate -> candidate.getName().equals(requestedTable))
            .findFirst()
            .orElseThrow(() -> new RuntimeException("Unknown table " + requestedTable));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*-
* #%L
* athena-tpcds
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.tpcds.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;

/**
 * Declares the Query Pass-Through (QPT) function signature for the TPCDS connector:
 * {@code system.query(TPCDS_CATALOG, TPCDS_SCHEMA, TPCDS_TABLE)}.
 */
public class TPCDSQueryPassthrough implements QueryPassthroughSignature
{
    private static final Logger LOGGER = LoggerFactory.getLogger(TPCDSQueryPassthrough.class);

    /** Name of the passthrough function. */
    public static final String NAME = "query";

    /** Schema (domain) under which the passthrough function is registered. */
    public static final String SCHEMA_NAME = "system";

    /** Argument keys expected by the passthrough function, in declaration order. */
    public static final String TPCDS_CATALOG = "TPCDS_CATALOG";
    public static final String TPCDS_SCHEMA = "TPCDS_SCHEMA";
    public static final String TPCDS_TABLE = "TPCDS_TABLE";

    @Override
    public String getFunctionSchema()
    {
        return SCHEMA_NAME;
    }

    @Override
    public String getFunctionName()
    {
        return NAME;
    }

    @Override
    public List<String> getFunctionArguments()
    {
        return Arrays.asList(TPCDS_CATALOG, TPCDS_SCHEMA, TPCDS_TABLE);
    }

    @Override
    public Logger getLogger()
    {
        return LOGGER;
    }
}

0 comments on commit 5a56a56

Please sign in to comment.