Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement docdb pagination #1613

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
import com.amazonaws.services.glue.model.Table;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.mongodb.MongoCommandException;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCursor;
import org.apache.arrow.util.VisibleForTesting;
Expand All @@ -56,6 +54,10 @@

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;

/**
* Handles metadata requests for the Athena DocumentDB Connector.
Expand Down Expand Up @@ -168,22 +170,21 @@ public ListSchemasResponse doListSchemaNames(BlockAllocator blockAllocator, List
public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTablesRequest request)
{
MongoClient client = getOrCreateConn(request);
Stream<String> tableNames = doListTablesWithCommand(client, request);

try (MongoCursor<String> itr = client.getDatabase(request.getSchemaName()).listCollectionNames().iterator()) {
List<TableName> tables = new ArrayList<>();
while (itr.hasNext()) {
tables.add(new TableName(request.getSchemaName(), itr.next()));
}
int startToken = request.getNextToken() != null ? Integer.parseInt(request.getNextToken()) : 0;
int pageSize = request.getPageSize();
String nextToken = null;

return new ListTablesResponse(request.getCatalogName(), tables, null);
if (pageSize != UNLIMITED_PAGE_SIZE_VALUE) {
logger.info("Starting at token {} w/ page size {}", startToken, pageSize);
tableNames = tableNames.skip(startToken).limit(request.getPageSize());
nextToken = Integer.toString(startToken + pageSize);
}
catch (MongoCommandException mongoCommandException) {
//do this in failed case instead of replace method in case API changes on doc db.
logger.warn("Exception on listCollectionNames on Mongo JAVA client, trying with mongo command line.", mongoCommandException);
List<TableName> tableNames = doListTablesWithCommand(client, request);

return new ListTablesResponse(request.getCatalogName(), tableNames.isEmpty() ? ImmutableList.of() : tableNames, null);
}
List<TableName> paginatedTables = tableNames.map(tableName -> new TableName(request.getSchemaName(), tableName)).collect(Collectors.toList());
logger.info("doListTables returned {} tables. Next token is {}", paginatedTables.size(), nextToken);
return new ListTablesResponse(request.getCatalogName(), paginatedTables, nextToken);
}

/**
Expand All @@ -208,18 +209,15 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables
* @param request
* @return
*/
private List<TableName> doListTablesWithCommand(MongoClient client, ListTablesRequest request)
private Stream<String> doListTablesWithCommand(MongoClient client, ListTablesRequest request)
{
logger.debug("doListTablesWithCommand Start");
List<TableName> tables = new ArrayList<>();
Document document = client.getDatabase(request.getSchemaName()).runCommand(new Document("listCollections", 1).append("nameOnly", true).append("authorizedCollections", true));
List<String> tables = new ArrayList<>();
Document queryDocument = new Document("listCollections", 1).append("nameOnly", true).append("authorizedCollections", true);
Document document = client.getDatabase(request.getSchemaName()).runCommand(queryDocument);

List<Document> list = ((Document) document.get("cursor")).getList("firstBatch", Document.class);
for (Document doc : list) {
tables.add(new TableName(request.getSchemaName(), doc.getString("name")));
}

return tables;
return list.stream().map(doc -> doc.getString("name")).sorted();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -149,13 +150,21 @@ public void doListTables()
tableNames.add("table2");
tableNames.add("table3");

Document tableNamesDocument = new Document("cursor",
new Document("firstBatch",
Arrays.asList(new Document("name", "table1"),
new Document("name", "table2"),
new Document("name", "table3"))));

MongoDatabase mockDatabase = mock(MongoDatabase.class);
when(mockClient.getDatabase(eq(DEFAULT_SCHEMA))).thenReturn(mockDatabase);
when(mockDatabase.listCollectionNames()).thenReturn(StubbingCursor.iterate(tableNames));
when(mockDatabase.runCommand(any())).thenReturn(tableNamesDocument);

ListTablesRequest req = new ListTablesRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, DEFAULT_SCHEMA,
null, UNLIMITED_PAGE_SIZE_VALUE);

ListTablesResponse res = handler.doListTables(allocator, req);

logger.info("doListTables - {}", res.getTables());

for (TableName next : res.getTables()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;

import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;
import static org.junit.Assert.*;
Expand Down