Skip to content

Commit

Permalink
Query passthrough changes for Cloudwatch connector (#1906)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trianz-Akshay authored Apr 12, 2024
1 parent 6fd2923 commit cf9ae17
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
Expand All @@ -38,16 +40,21 @@
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasResponse;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.logs.AWSLogs;
import com.amazonaws.services.logs.AWSLogsClientBuilder;
import com.amazonaws.services.logs.model.DescribeLogGroupsRequest;
import com.amazonaws.services.logs.model.DescribeLogGroupsResult;
import com.amazonaws.services.logs.model.DescribeLogStreamsRequest;
import com.amazonaws.services.logs.model.DescribeLogStreamsResult;
import com.amazonaws.services.logs.model.GetQueryResultsResult;
import com.amazonaws.services.logs.model.LogStream;
import com.amazonaws.services.logs.model.ResultField;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.types.Types;
Expand All @@ -60,11 +67,13 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeoutException;

import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchExceptionFilter.EXCEPTION_FILTER;
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchUtils.getResult;

/**
* Handles metadata requests for the Athena Cloudwatch Connector.
Expand Down Expand Up @@ -117,6 +126,7 @@ public class CloudwatchMetadataHandler
private final AWSLogs awsLogs;
private final ThrottlingInvoker invoker;
private final CloudwatchTableResolver tableResolver;
private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough();

public CloudwatchMetadataHandler(java.util.Map<String, String> configOptions)
{
Expand Down Expand Up @@ -241,6 +251,9 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
/**
 * Adds the partition columns (log stream size and log group) to the partition schema.
 * When the requested table name matches the query passthrough function signature
 * (case-insensitive), no partition columns are added.
 *
 * @param partitionSchemaBuilder The builder that receives the partition fields.
 * @param request The table layout request whose table name is inspected.
 */
@Override
public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request)
{
    String qualifiedTableName = request.getTableName().getQualifiedTableName();
    boolean passthroughTable = qualifiedTableName.equalsIgnoreCase(queryPassthrough.getFunctionSignature());
    if (!passthroughTable) {
        // Regular tables are partitioned by log stream, so expose the partition metadata columns.
        partitionSchemaBuilder.addField(LOG_STREAM_SIZE_FIELD, new ArrowType.Int(64, true));
        partitionSchemaBuilder.addField(LOG_GROUP_FIELD, Types.MinorType.VARCHAR.getType());
    }
}
Expand All @@ -257,6 +270,10 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl
public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker)
throws Exception
{
if (request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) {
return;
}

CloudwatchTableName cwTableName = tableResolver.validateTable(request.getTableName());

DescribeLogStreamsRequest cwRequest = new DescribeLogStreamsRequest(cwTableName.getLogGroupName());
Expand Down Expand Up @@ -290,6 +307,15 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request
@Override
public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
{
if (request.getConstraints().isQueryPassThrough()) {
//Since this is QPT query we return a fixed split.
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
return new GetSplitsResponse(request.getCatalogName(),
Split.newBuilder(makeSpillLocation(request), makeEncryptionKey())
.applyProperties(qptArguments)
.build());
}

int partitionContd = decodeContinuationToken(request);
Set<Split> splits = new HashSet<>();
Block partitions = request.getPartitions();
Expand Down Expand Up @@ -325,6 +351,35 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
return new GetSplitsResponse(request.getCatalogName(), splits, null);
}

/**
 * Reports the capabilities of this data source. The only capability contributed here is
 * query passthrough, which the passthrough helper adds based on the connector's config options.
 *
 * @param allocator Unused by this implementation.
 * @param request The capabilities request; supplies the catalog name echoed in the response.
 * @return A response carrying the catalog name and the collected capabilities map.
 */
@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
    ImmutableMap.Builder<String, List<OptimizationSubType>> capabilityBuilder = ImmutableMap.builder();
    queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilityBuilder, configOptions);

    String catalogName = request.getCatalogName();
    Map<String, List<OptimizationSubType>> supportedCapabilities = capabilityBuilder.build();
    return new GetDataSourceCapabilitiesResponse(catalogName, supportedCapabilities);
}

/**
 * Derives the schema of a query passthrough request by running the passed-through
 * CloudWatch Logs query with a limit of 1 and reading the field names of the first row.
 * Every discovered field is exposed as VARCHAR. If the query yields no rows, the
 * returned schema has no fields.
 *
 * @param allocator Unused by this implementation.
 * @param request The table request; must be a query passthrough request.
 * @return A table response with the catalog name, table name and derived schema.
 * @throws IllegalArgumentException if the request is not a query passthrough request.
 * @throws Exception if validating the arguments or running the query fails.
 */
@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
    if (!request.isQueryPassthrough()) {
        // Plain concatenation: "{}" placeholders are only expanded by logger calls, never by exception constructors.
        throw new IllegalArgumentException("No Query passed through [" + request + "]");
    }
    // Validate the arguments up front, mirroring what the record handler does before running the query.
    queryPassthrough.verify(request.getQueryPassthroughArguments());

    // A single row is enough: only the column names are needed here.
    GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, request.getQueryPassthroughArguments(), 1);
    SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();
    List<List<ResultField>> rows = getQueryResultsResult.getResults();
    if (rows != null && !rows.isEmpty()) {
        for (ResultField field : rows.get(0)) {
            schemaBuilder.addField(field.getField(), Types.MinorType.VARCHAR.getType());
        }
    }

    return new GetTableResponse(request.getCatalogName(),
            request.getTableName(),
            schemaBuilder.build());
}

/**
* Used to handle paginated requests.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
import com.amazonaws.services.logs.AWSLogs;
import com.amazonaws.services.logs.AWSLogsClientBuilder;
import com.amazonaws.services.logs.model.GetLogEventsRequest;
import com.amazonaws.services.logs.model.GetLogEventsResult;
import com.amazonaws.services.logs.model.GetQueryResultsResult;
import com.amazonaws.services.logs.model.OutputLogEvent;
import com.amazonaws.services.logs.model.ResultField;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
Expand All @@ -46,6 +49,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;

Expand All @@ -54,6 +59,7 @@
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_MSG_FIELD;
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_STREAM_FIELD;
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchMetadataHandler.LOG_TIME_FIELD;
import static com.amazonaws.athena.connectors.cloudwatch.CloudwatchUtils.getResult;

/**
* Handles data read record requests for the Athena Cloudwatch Connector.
Expand All @@ -73,15 +79,16 @@ public class CloudwatchRecordHandler
private final ThrottlingInvoker invoker;
private final AtomicLong count = new AtomicLong(0);
private final AWSLogs awsLogs;
private final CloudwatchQueryPassthrough queryPassthrough = new CloudwatchQueryPassthrough();

public CloudwatchRecordHandler(java.util.Map<String, String> configOptions)
{
this(
AmazonS3ClientBuilder.defaultClient(),
AWSSecretsManagerClientBuilder.defaultClient(),
AmazonAthenaClientBuilder.defaultClient(),
AWSLogsClientBuilder.defaultClient(),
configOptions);
AmazonS3ClientBuilder.defaultClient(),
AWSSecretsManagerClientBuilder.defaultClient(),
AmazonAthenaClientBuilder.defaultClient(),
AWSLogsClientBuilder.defaultClient(),
configOptions);
}

@VisibleForTesting
Expand All @@ -99,54 +106,79 @@ protected CloudwatchRecordHandler(AmazonS3 amazonS3, AWSSecretsManager secretsMa
*/
@Override
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
throws TimeoutException
throws TimeoutException, InterruptedException
{
String continuationToken = null;
TableName tableName = recordsRequest.getTableName();
Split split = recordsRequest.getSplit();
invoker.setBlockSpiller(spiller);
do {
final String actualContinuationToken = continuationToken;
GetLogEventsResult logEventsResult = invoker.invoke(() -> awsLogs.getLogEvents(
pushDownConstraints(recordsRequest.getConstraints(),
new GetLogEventsRequest()
.withLogGroupName(split.getProperty(LOG_GROUP_FIELD))
//We use the property instead of the table name because of the special all_streams table
.withLogStreamName(split.getProperty(LOG_STREAM_FIELD))
.withNextToken(actualContinuationToken)
// must be set to use nextToken correctly
.withStartFromHead(true)
)));

if (continuationToken == null || !continuationToken.equals(logEventsResult.getNextForwardToken())) {
continuationToken = logEventsResult.getNextForwardToken();
}
else {
continuationToken = null;
}
if (recordsRequest.getConstraints().isQueryPassThrough()) {
getQueryPassthreoughResults(spiller, recordsRequest);
}
else {
String continuationToken = null;
TableName tableName = recordsRequest.getTableName();
Split split = recordsRequest.getSplit();
invoker.setBlockSpiller(spiller);
do {
final String actualContinuationToken = continuationToken;
GetLogEventsResult logEventsResult = invoker.invoke(() -> awsLogs.getLogEvents(
pushDownConstraints(recordsRequest.getConstraints(),
new GetLogEventsRequest()
.withLogGroupName(split.getProperty(LOG_GROUP_FIELD))
//We use the property instead of the table name because of the special all_streams table
.withLogStreamName(split.getProperty(LOG_STREAM_FIELD))
.withNextToken(actualContinuationToken)
// must be set to use nextToken correctly
.withStartFromHead(true)
)));

for (OutputLogEvent ole : logEventsResult.getEvents()) {
spiller.writeRows((Block block, int rowNum) -> {
boolean matched = true;
matched &= block.offerValue(LOG_STREAM_FIELD, rowNum, split.getProperty(LOG_STREAM_FIELD));
matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.getTimestamp());
matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.getMessage());
return matched ? 1 : 0;
});
if (continuationToken == null || !continuationToken.equals(logEventsResult.getNextForwardToken())) {
continuationToken = logEventsResult.getNextForwardToken();
}
else {
continuationToken = null;
}

for (OutputLogEvent ole : logEventsResult.getEvents()) {
spiller.writeRows((Block block, int rowNum) -> {
boolean matched = true;
matched &= block.offerValue(LOG_STREAM_FIELD, rowNum, split.getProperty(LOG_STREAM_FIELD));
matched &= block.offerValue(LOG_TIME_FIELD, rowNum, ole.getTimestamp());
matched &= block.offerValue(LOG_MSG_FIELD, rowNum, ole.getMessage());
return matched ? 1 : 0;
});
}

logger.info("readWithConstraint: LogGroup[{}] LogStream[{}] Continuation[{}] rows[{}]",
tableName.getSchemaName(), tableName.getTableName(), continuationToken,
logEventsResult.getEvents().size());
}
while (continuationToken != null && queryStatusChecker.isQueryRunning());
}
}

logger.info("readWithConstraint: LogGroup[{}] LogStream[{}] Continuation[{}] rows[{}]",
tableName.getSchemaName(), tableName.getTableName(), continuationToken,
logEventsResult.getEvents().size());
/**
 * Executes a query passthrough request and writes its results to the spiller.
 * The passthrough arguments are read from the request's constraints, verified, and then
 * executed through getResult using the value of the LIMIT argument as the row limit.
 *
 * @param spiller The spiller that receives one row per query result entry.
 * @param recordsRequest The read request whose constraints carry the passthrough arguments.
 * @throws TimeoutException declared because getResult declares it.
 * @throws InterruptedException declared because getResult declares it.
 */
private void getQueryPassthreoughResults(BlockSpiller spiller, ReadRecordsRequest recordsRequest) throws TimeoutException, InterruptedException
{
Map<String, String> qptArguments = recordsRequest.getConstraints().getQueryPassthroughArguments();
// Validate the arguments before they are parsed and sent to CloudWatch Logs.
queryPassthrough.verify(qptArguments);
GetQueryResultsResult getQueryResultsResult = getResult(invoker, awsLogs, qptArguments, Integer.parseInt(qptArguments.get(CloudwatchQueryPassthrough.LIMIT)));

// Each entry of the result list is one row, represented as a list of field/value pairs.
for (List<ResultField> resultList : getQueryResultsResult.getResults()) {
spiller.writeRows((Block block, int rowNum) -> {
for (ResultField resultField : resultList) {
boolean matched = true;
// Offer the value under the column named after the result field.
matched &= block.offerValue(resultField.getField(), rowNum, resultField.getValue());
if (!matched) {
// Report zero rows written as soon as any offer returns false.
return 0;
}
}
return 1;
});
}
// NOTE(review): the `while` line below looks like a leftover from the earlier do/while loop shown in this
// diff view; neither `continuationToken` nor `queryStatusChecker` is declared in this method. Confirm
// against the actual file.
while (continuationToken != null && queryStatusChecker.isQueryRunning());
}

/**
* Attempts to push down predicates into Cloudwatch Logs by decorating the Cloudwatch Logs request.
*
* @param constraints The constraints for the read as provided by Athena based on the customer's query.
* @param request The Cloudwatch Logs request to inject predicates to.
* @param request The Cloudwatch Logs request to inject predicates to.
* @return The decorated Cloudwatch Logs request.
* @note This impl currently only pushing down SortedRangeSet filters (>=, =<, between) on the log time column.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*-
* #%L
* athena-cloudwatch
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.cloudwatch;

import com.amazonaws.athena.connector.lambda.ThrottlingInvoker;
import com.amazonaws.athena.connectors.cloudwatch.qpt.CloudwatchQueryPassthrough;
import com.amazonaws.services.logs.AWSLogs;
import com.amazonaws.services.logs.model.GetQueryResultsRequest;
import com.amazonaws.services.logs.model.GetQueryResultsResult;
import com.amazonaws.services.logs.model.StartQueryRequest;
import com.amazonaws.services.logs.model.StartQueryResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Map;
import java.util.concurrent.TimeoutException;

/**
 * Static helpers for running CloudWatch Logs queries from query passthrough arguments
 * and waiting for their results.
 */
public final class CloudwatchUtils
{
    private static final Logger logger = LoggerFactory.getLogger(CloudwatchUtils.class);

    /** Maximum time, in minutes, that {@link #getResult} keeps polling for a query to complete. */
    public static final int RESULT_TIMEOUT = 10;

    /** Pause between two consecutive polls of the query status, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 1000L;

    private static final String STATUS_COMPLETE = "Complete";

    private CloudwatchUtils() {}

    /**
     * Builds a start-query request from the query passthrough arguments.
     *
     * @param qptArguments The passthrough arguments; the end time, start time, query string and
     *                     log group names entries are read from it.
     * @return The populated request (no limit set).
     * @throws NumberFormatException if the start or end time is missing or not a valid long.
     * @throws IllegalArgumentException if the log group names argument is missing or blank.
     */
    public static StartQueryRequest startQueryRequest(Map<String, String> qptArguments)
    {
        return new StartQueryRequest()
                .withEndTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.ENDTIME)))
                .withStartTime(Long.valueOf(qptArguments.get(CloudwatchQueryPassthrough.STARTTIME)))
                .withQueryString(qptArguments.get(CloudwatchQueryPassthrough.QUERYSTRING))
                .withLogGroupNames(getLogGroupNames(qptArguments));
    }

    /**
     * Parses the comma-separated log group names argument into an array, stripping one leading
     * and one trailing double quote from each name.
     */
    private static String[] getLogGroupNames(Map<String, String> qptArguments)
    {
        String rawNames = qptArguments.get(CloudwatchQueryPassthrough.LOGGROUPNAMES);
        if (rawNames == null || rawNames.trim().isEmpty()) {
            throw new IllegalArgumentException("Missing required query passthrough argument: " + CloudwatchQueryPassthrough.LOGGROUPNAMES);
        }

        // Tolerate any amount of whitespace around the commas instead of requiring exactly ", ".
        String[] logGroupNames = rawNames.trim().split("\\s*,\\s*");
        for (int i = 0; i < logGroupNames.length; i++) {
            logGroupNames[i] = logGroupNames[i].replaceAll("^\"|\"$", "");
        }
        // Join explicitly: passing a String[] straight to the logger binds it to the varargs
        // parameter, so only the first name would be printed.
        logger.info("log group names {}", String.join(", ", logGroupNames));
        return logGroupNames;
    }

    /**
     * Starts the given query. Despite its name, this returns the start-query result
     * (which carries the query id), not the query's rows.
     *
     * @param awsLogs The CloudWatch Logs client.
     * @param startQueryRequest The request to submit.
     * @return The result of starting the query.
     */
    public static StartQueryResult getQueryResult(AWSLogs awsLogs, StartQueryRequest startQueryRequest)
    {
        return awsLogs.startQuery(startQueryRequest);
    }

    /**
     * Fetches the current results and status of a previously started query.
     *
     * @param awsLogs The CloudWatch Logs client.
     * @param startQueryResult The start-query result whose query id is looked up.
     * @return The current query results.
     */
    public static GetQueryResultsResult getQueryResults(AWSLogs awsLogs, StartQueryResult startQueryResult)
    {
        return awsLogs.getQueryResults(new GetQueryResultsRequest().withQueryId(startQueryResult.getQueryId()));
    }

    /**
     * Starts a query built from the passthrough arguments and polls until it completes.
     *
     * @param invoker The invoker through which every CloudWatch Logs call is made.
     * @param awsLogs The CloudWatch Logs client.
     * @param qptArguments The passthrough arguments describing the query.
     * @param limit The maximum number of rows the query should return.
     * @return The results of the completed query.
     * @throws TimeoutException if the invoker gives up on a call.
     * @throws InterruptedException if the thread is interrupted while waiting between polls.
     * @throws RuntimeException if the query ends in a failed, cancelled or timed-out state, or
     *                          does not complete within {@link #RESULT_TIMEOUT} minutes.
     */
    public static GetQueryResultsResult getResult(ThrottlingInvoker invoker, AWSLogs awsLogs, Map<String, String> qptArguments, int limit) throws TimeoutException, InterruptedException
    {
        StartQueryResult startQueryResult = invoker.invoke(() -> getQueryResult(awsLogs, startQueryRequest(qptArguments).withLimit(limit)));
        Instant startTime = Instant.now();
        while (true) {
            GetQueryResultsResult getQueryResultsResult = invoker.invoke(() -> getQueryResults(awsLogs, startQueryResult));
            String status = getQueryResultsResult.getStatus();

            // Constant-first comparison keeps a missing status from raising a NullPointerException.
            if (STATUS_COMPLETE.equalsIgnoreCase(status)) {
                return getQueryResultsResult;
            }
            // A query in one of these states will never complete; fail now instead of polling until the timeout.
            if (isTerminalFailure(status)) {
                throw new RuntimeException("Query " + startQueryResult.getQueryId() + " ended with status " + status + ".");
            }
            if (ChronoUnit.MINUTES.between(startTime, Instant.now()) >= RESULT_TIMEOUT) {
                throw new RuntimeException("Query execution timeout exceeded.");
            }
            // Only wait when another poll is actually going to happen.
            Thread.sleep(POLL_INTERVAL_MILLIS);
        }
    }

    /** Returns true for query statuses after which the query can no longer complete. */
    private static boolean isTerminalFailure(String status)
    {
        return "Failed".equalsIgnoreCase(status)
                || "Cancelled".equalsIgnoreCase(status)
                || "Timeout".equalsIgnoreCase(status);
    }
}
Loading

0 comments on commit cf9ae17

Please sign in to comment.