diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java index ecca3559e5..cd52e12683 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchMetadataHandler.java @@ -29,6 +29,8 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler; +import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest; +import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse; import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest; @@ -38,7 +40,9 @@ import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse; import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest; import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse; +import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; +import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.logs.AWSLogs; import com.amazonaws.services.logs.AWSLogsClientBuilder; @@ -46,8 +50,11 @@ import com.amazonaws.services.logs.model.DescribeLogGroupsResult; import com.amazonaws.services.logs.model.DescribeLogStreamsRequest; import com.amazonaws.services.logs.model.DescribeLogStreamsResult; +import com.amazonaws.services.logs.model.GetQueryResultsResult; import com.amazonaws.services.logs.model.LogStream; +import com.amazonaws.services.logs.model.ResultField; import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; @@ -60,11 +67,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.TimeoutException; import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE; import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchExceptionFilter.EXCEPTION_FILTER; +import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchUtils.getResult; /** * Handles metadata requests for the Athena Cloudwatch Connector. @@ -117,6 +126,7 @@ public class CloudwatchMetadataHandler private final AWSLogs awsLogs; private final ThrottlingInvoker invoker; private final CloudwatchTableResolver tableResolver; + private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough(); public CloudwatchMetadataHandler(java.util.Map configOptions) { @@ -241,6 +251,9 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques @Override public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request) { + if (request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) { + return; + } partitionSchemaBuilder.addField(LOG_STREAM_SIZE_FIELD, new ArrowType.Int(64, true)); partitionSchemaBuilder.addField(LOG_GROUP_FIELD, Types.MinorType.VARCHAR.getType()); } @@ -257,6 +270,10 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws Exception { + if (request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) { + return; + } + CloudwatchTableName cwTableName = tableResolver.validateTable(request.getTableName()); DescribeLogStreamsRequest cwRequest = new DescribeLogStreamsRequest(cwTableName.getLogGroupName()); @@ -290,6 +307,15 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request @Override public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request) { + if (request.getConstraints().isQueryPassThrough()) { + //Since this is QPT query we return a fixed split. + Map qptArguments = request.getConstraints().getQueryPassthroughArguments(); + return new GetSplitsResponse(request.getCatalogName(), + Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) + .applyProperties(qptArguments) + .build()); + } + int partitionContd = decodeContinuationToken(request); Set splits = new HashSet<>(); Block partitions = request.getPartitions(); @@ -325,6 +351,35 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest return new GetSplitsResponse(request.getCatalogName(), splits, null); } + @Override + public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request) + { + ImmutableMap.Builder> capabilities = ImmutableMap.builder(); + queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions); + + return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build()); + } + + @Override + public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception + { + if (!request.isQueryPassthrough()) { + throw new IllegalArgumentException("No Query passed through [{}]" + request); + } + // to get column names with limit 1 + GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, request.getQueryPassthroughArguments(), 1); + SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); + if (!getQueryResultsResult.getResults().isEmpty()) { + for (ResultField field : getQueryResultsResult.getResults().get(0)) { + schemaBuilder.addField(field.getField(), Types.MinorType.VARCHAR.getType()); + } + } + + return new GetTableResponse(request.getCatalogName(), + request.getTableName(), + schemaBuilder.build()); + } + /** * Used to handle paginated requests. * diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java index b96ee64e81..a5d29f0f9b 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java @@ -31,13 +31,16 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; +import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.athena.AmazonAthenaClientBuilder; import com.amazonaws.services.logs.AWSLogs; import com.amazonaws.services.logs.AWSLogsClientBuilder; import com.amazonaws.services.logs.model.GetLogEventsRequest; import com.amazonaws.services.logs.model.GetLogEventsResult; +import com.amazonaws.services.logs.model.GetQueryResultsResult; import com.amazonaws.services.logs.model.OutputLogEvent; +import com.amazonaws.services.logs.model.ResultField; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.amazonaws.services.secretsmanager.AWSSecretsManager; @@ -46,6 +49,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; +import java.util.Map; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; @@ -54,6 +59,7 @@ import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_MSG_FIELD; import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_STREAM_FIELD; import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_TIME_FIELD; +import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchUtils.getResult; /** * Handles data read record requests for the Athena Cloudwatch Connector. @@ -73,15 +79,16 @@ public class CloudwatchRecordHandler private final ThrottlingInvoker invoker; private final AtomicLong count = new AtomicLong(0); private final AWSLogs awsLogs; + private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough(); public CloudwatchRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), - AWSSecretsManagerClientBuilder.defaultClient(), - AmazonAthenaClientBuilder.defaultClient(), - AWSLogsClientBuilder.defaultClient(), - configOptions); + AmazonS3ClientBuilder.defaultClient(), + AWSSecretsManagerClientBuilder.defaultClient(), + AmazonAthenaClientBuilder.defaultClient(), + AWSLogsClientBuilder.defaultClient(), + configOptions); } @VisibleForTesting @@ -99,54 +106,79 @@ protected CloudwatchRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsMa */ @Override protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) - throws TimeoutException + throws TimeoutException, InterruptedException { - String continuationToken = null; - TableName tableName = recordsRequest.getTableName(); - Split split = recordsRequest.getSplit(); - invoker.setBlockSpiller(spiller); - do { - final String actualContinuationToken = continuationToken; - GetLogEventsResult logEventsResult = invoker.invoke(() -> awsLogs.getLogEvents( - pushDownConstraints(recordsRequest.getConstraints(), - new GetLogEventsRequest() - .withLogGroupName(split.getProperty(LOG_GROUP_FIELD)) - //We use the property instead of the table name because of the special all_streams table - .withLogStreamName(split.getProperty(LOG_STREAM_FIELD)) - .withNextToken(actualContinuationToken) - // must be set to use nextToken correctly - .withStartFromHead(true) - ))); - - if (continuationToken == null || !continuationToken.equals(logEventsResult.getNextForwardToken())) { - continuationToken = logEventsResult.getNextForwardToken(); - } - else { - continuationToken = null; - } + if (recordsRequest.getConstraints().isQueryPassThrough()) { + getQueryPassthreoughResults(spiller, recordsRequest); + } + else { + String continuationToken = null; + TableName tableName = recordsRequest.getTableName(); + Split split = recordsRequest.getSplit(); + invoker.setBlockSpiller(spiller); + do { + final String actualContinuationToken = continuationToken; + GetLogEventsResult logEventsResult = invoker.invoke(() -> awsLogs.getLogEvents( + pushDownConstraints(recordsRequest.getConstraints(), + new GetLogEventsRequest() + .withLogGroupName(split.getProperty(LOG_GROUP_FIELD)) + //We use the property instead of the table name because of the special all_streams table + .withLogStreamName(split.getProperty(LOG_STREAM_FIELD)) + .withNextToken(actualContinuationToken) + // must be set to use nextToken correctly + .withStartFromHead(true) + ))); - for (OutputLogEvent ole : logEventsResult.getEvents()) { - spiller.writeRows((Block block, int rowNum) -> { - boolean matched = true; - matched &= block.offerValue(LOG_STREAM_FIELD, rowNum, split.getProperty(LOG_STREAM_FIELD)); - matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.getTimestamp()); - matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.getMessage()); - return matched ? 1 : 0; - }); + if (continuationToken == null || !continuationToken.equals(logEventsResult.getNextForwardToken())) { + continuationToken = logEventsResult.getNextForwardToken(); + } + else { + continuationToken = null; + } + + for (OutputLogEvent ole : logEventsResult.getEvents()) { + spiller.writeRows((Block block, int rowNum) -> { + boolean matched = true; + matched &= block.offerValue(LOG_STREAM_FIELD, rowNum, split.getProperty(LOG_STREAM_FIELD)); + matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.getTimestamp()); + matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.getMessage()); + return matched ? 1 : 0; + }); + } + + logger.info("readWithConstraint: LogGroup[{}] LogStream[{}] Continuation[{}] rows[{}]", + tableName.getSchemaName(), tableName.getTableName(), continuationToken, + logEventsResult.getEvents().size()); } + while (continuationToken != null && queryStatusChecker.isQueryRunning()); + } + } - logger.info("readWithConstraint: LogGroup[{}] LogStream[{}] Continuation[{}] rows[{}]", - tableName.getSchemaName(), tableName.getTableName(), continuationToken, - logEventsResult.getEvents().size()); + private void getQueryPassthreoughResults(BlockSpiller spiller, ReadRecordsRequest recordsRequest) throws TimeoutException, InterruptedException + { + Map qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments(); + queryPassthrough.verify(qptArguments); + GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, qptArguments, Integer.parseInt(qptArguments.get(CloudwatchQueryPassthrough.LIMIT))); + + for (List resultList : getQueryResultsResult.getResults()) { + spiller.writeRows((Block block, int rowNum) -> { + for (ResultField resultField : resultList) { + boolean matched = true; + matched &= block.offerValue(resultField.getField(), rowNum, resultField.getValue()); + if (!matched) { + return 0; + } + } + return 1; + }); } - while (continuationToken != null && queryStatusChecker.isQueryRunning()); } /** * Attempts to push down predicates into Cloudwatch Logs by decorating the Cloudwatch Logs request. * * @param constraints The constraints for the read as provided by Athena based on the customer's query. - * @param request The Cloudwatch Logs request to inject predicates to. + * @param request The Cloudwatch Logs request to inject predicates to. * @return The decorated Cloudwatch Logs request. * @note This impl currently only pushing down SortedRangeSet filters (>=, =<, between) on the log time column. */ diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java new file mode 100644 index 0000000000..5c19ec17ee --- /dev/null +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchUtils.java @@ -0,0 +1,89 @@ +/*- + * #%L + * athena-cloudwatch + * %% + * Copyright (C) 2019 - 2024 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.cloudwatch; + +import com.amazonaws.athena.connector.lambda.ThrottlingInvoker; +import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough; +import com.amazonaws.services.logs.AWSLogs; +import com.amazonaws.services.logs.model.GetQueryResultsRequest; +import com.amazonaws.services.logs.model.GetQueryResultsResult; +import com.amazonaws.services.logs.model.StartQueryRequest; +import com.amazonaws.services.logs.model.StartQueryResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import java.util.concurrent.TimeoutException; + +public final class CloudwatchUtils +{ + private static final Logger logger = LoggerFactory.getLogger(CloudwatchUtils.class); + public static final int RESULT_TIMEOUT = 10; + private CloudwatchUtils() {} + public static StartQueryRequest startQueryRequest(Map qptArguments) + { + return new StartQueryRequest().withEndTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.ENDTIME))).withStartTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.STARTTIME))) + .withQueryString(qptArguments.get(CloudwatchQueryPassthrough.QUERYSTRING)).withLogGroupNames(getLogGroupNames(qptArguments)); + } + + private static String[] getLogGroupNames(Map qptArguments) + { + String[] logGroupNames = qptArguments.get(CloudwatchQueryPassthrough.LOGGROUPNAMES).split(", "); + logger.info("log group names {}", logGroupNames); + for (int i = 0; i < logGroupNames.length; i++) { + logGroupNames[i] = logGroupNames[i].replaceAll("^\"|\"$", ""); + } + return logGroupNames; + } + + public static StartQueryResult getQueryResult(AWSLogs awsLogs, StartQueryRequest startQueryRequest) + { + return awsLogs.startQuery(startQueryRequest); + } + + public static GetQueryResultsResult getQueryResults(AWSLogs awsLogs, StartQueryResult startQueryResult) + { + return awsLogs.getQueryResults(new GetQueryResultsRequest().withQueryId(startQueryResult.getQueryId())); + } + + public static GetQueryResultsResult getResult(ThrottlingInvoker invoker, AWSLogs awsLogs, Map qptArguments, int limit) throws TimeoutException, InterruptedException + { + StartQueryResult startQueryResult = invoker.invoke(() -> getQueryResult(awsLogs, startQueryRequest(qptArguments).withLimit(limit))); + String status = null; + GetQueryResultsResult getQueryResultsResult; + Instant startTime = Instant.now(); // Record the start time + do { + getQueryResultsResult = invoker.invoke(() -> getQueryResults(awsLogs, startQueryResult)); + status = getQueryResultsResult.getStatus(); + Thread.sleep(1000); + + // Check if 10 minutes have passed + Instant currentTime = Instant.now(); + long elapsedMinutes = ChronoUnit.MINUTES.between(startTime, currentTime); + if (elapsedMinutes >= RESULT_TIMEOUT) { + throw new RuntimeException("Query execution timeout exceeded."); + } + } while (!status.equalsIgnoreCase("Complete")); + + return getQueryResultsResult; + } +} diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/qpt/CloudwatchQueryPassthrough.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/qpt/CloudwatchQueryPassthrough.java new file mode 100644 index 0000000000..97d88acfcb --- /dev/null +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/qpt/CloudwatchQueryPassthrough.java @@ -0,0 +1,64 @@ +/*- + * #%L + * athena-docdb + * %% + * Copyright (C) 2019 - 2024 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.cloudwatch.qpt; + +import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; + +public final class CloudwatchQueryPassthrough implements QueryPassthroughSignature +{ + private static final String SCHEMA_NAME = "system"; + private static final String NAME = "query"; + // List of arguments for the query, statically initialized as it always contains the same value. + public static final String ENDTIME = "ENDTIME"; + public static final String LIMIT = "LIMIT"; + public static final String LOGGROUPNAMES = "LOGGROUPNAMES"; + public static final String QUERYSTRING = "QUERYSTRING"; + public static final String STARTTIME = "STARTTIME"; + private static final Logger LOGGER = LoggerFactory.getLogger(CloudwatchQueryPassthrough.class); + + @Override + public String getFunctionSchema() + { + return SCHEMA_NAME; + } + + @Override + public String getFunctionName() + { + return NAME; + } + + @Override + public List getFunctionArguments() + { + return Arrays.asList(ENDTIME, QUERYSTRING, STARTTIME, LOGGROUPNAMES, LIMIT); + } + + @Override + public Logger getLogger() + { + return LOGGER; + } +}