Skip to content

Commit

Permalink
oracle casing flag (awslabs#2415)
Browse files Browse the repository at this point in the history
  • Loading branch information
aimethed authored and Jithendar12 committed Dec 2, 2024
1 parent 94c1ce0 commit 14617b4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.athena.AthenaClient;
Expand All @@ -78,6 +79,7 @@
import java.util.Set;
import java.util.stream.Collectors;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.MODULUS_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.NULLIF_FUNCTION_NAME;
Expand All @@ -89,13 +91,15 @@
public class OracleMetadataHandler
extends JdbcMetadataHandler
{
static final String GET_PARTITIONS_QUERY = "Select DISTINCT PARTITION_NAME FROM USER_TAB_PARTITIONS where table_name= ?";
static final String BLOCK_PARTITION_COLUMN_NAME = "PARTITION_NAME";
static final String GET_PARTITIONS_QUERY = "Select DISTINCT PARTITION_NAME as \"partition_name\" FROM USER_TAB_PARTITIONS where table_name= ?";
static final String BLOCK_PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String ALL_PARTITIONS = "0";
static final String PARTITION_COLUMN_NAME = "PARTITION_NAME";
static final String PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String CASING_MODE = "casing_mode";
private static final Logger LOGGER = LoggerFactory.getLogger(OracleMetadataHandler.class);
private static final int MAX_SPLITS_PER_REQUEST = 1000_000;
private static final String COLUMN_NAME = "COLUMN_NAME";
private static final String ORACLE_QUOTE_CHARACTER = "\"";

static final String LIST_PAGINATED_TABLES_QUERY = "SELECT TABLE_NAME as \"TABLE_NAME\", OWNER as \"TABLE_SCHEM\" FROM all_tables WHERE owner = ? ORDER BY TABLE_NAME OFFSET ? ROWS FETCH NEXT ? ROWS ONLY";

Expand Down Expand Up @@ -154,15 +158,18 @@ public Schema getPartitionSchema(final String catalogName)
public void getPartitions(final BlockWriter blockWriter, final GetTableLayoutRequest getTableLayoutRequest, QueryStatusChecker queryStatusChecker)
throws Exception
{
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), getTableLayoutRequest.getTableName().getSchemaName(),
getTableLayoutRequest.getTableName().getTableName());
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), transformString(getTableLayoutRequest.getTableName().getSchemaName(), true),
transformString(getTableLayoutRequest.getTableName().getTableName(), true));
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
List<String> parameters = Arrays.asList(getTableLayoutRequest.getTableName().getTableName().toUpperCase());
List<String> parameters = Arrays.asList(transformString(getTableLayoutRequest.getTableName().getTableName(), true));
//try (Statement statement = connection.createStatement(); ResultSet resultSet = statement.executeQuery(GET_PARTITIONS_QUERY + ))
try (PreparedStatement preparedStatement = new PreparedStatementBuilder().withConnection(connection).withQuery(GET_PARTITIONS_QUERY).withParameters(parameters).build();
ResultSet resultSet = preparedStatement.executeQuery()) {
ResultSet resultSet = preparedStatement.executeQuery()) {
// Return a single partition if no partitions defined
if (!resultSet.next()) {
LOGGER.debug("here");
blockWriter.writeRows((Block block, int rowNum) -> {
LOGGER.debug("Parameters: " + BLOCK_PARTITION_COLUMN_NAME + " " + rowNum + " " + ALL_PARTITIONS);
block.setValue(BLOCK_PARTITION_COLUMN_NAME, rowNum, ALL_PARTITIONS);
LOGGER.info("Adding partition {}", ALL_PARTITIONS);
//we wrote 1 row so we return 1
Expand Down Expand Up @@ -305,7 +312,7 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge
{
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
Schema partitionSchema = getPartitionSchema(getTableRequest.getCatalogName());
TableName tableName = new TableName(getTableRequest.getTableName().getSchemaName().toUpperCase(), getTableRequest.getTableName().getTableName().toUpperCase());
TableName tableName = new TableName(transformString(getTableRequest.getTableName().getSchemaName(), false), transformString(getTableRequest.getTableName().getTableName(), false));
return new GetTableResponse(getTableRequest.getCatalogName(), tableName, getSchema(connection, tableName, partitionSchema),
partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()));
}
Expand Down Expand Up @@ -357,11 +364,12 @@ private Schema getSchema(Connection jdbcConnection, TableName tableName, Schema
*/
try
(PreparedStatement stmt = connection.prepareStatement("select COLUMN_NAME ,DATA_TYPE from USER_TAB_COLS where table_name =?")) {
stmt.setString(1, tableName.getTableName().toUpperCase());
stmt.setString(1, transformString(tableName.getTableName(), true));
ResultSet dataTypeResultSet = stmt.executeQuery();
while (dataTypeResultSet.next()) {
hashMap.put(dataTypeResultSet.getString(COLUMN_NAME).trim(), dataTypeResultSet.getString("DATA_TYPE").trim());
}
LOGGER.debug("hashMap", hashMap.toString());
while (resultSet.next()) {
ArrowType columnType = JdbcArrowTypeConverter.toArrowType(
resultSet.getInt("DATA_TYPE"),
Expand Down Expand Up @@ -433,4 +441,25 @@ private Schema getSchema(Connection jdbcConnection, TableName tableName, Schema
return schemaBuilder.build();
}
}

/**
* Always adds double quotes around the string
* If the lambda uses a glue connection, return the string as is (lowercased by the trino engine)
* Otherwise uppercase it (the default of oracle)
* @param str
* @param quote
* @return
*/
private String transformString(String str, boolean quote)
{
boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION));
boolean uppercase = configOptions.getOrDefault(CASING_MODE, isGlueConnection ? "lower" : "upper").toLowerCase().equals("upper");
if (uppercase) {
str = str.toUpperCase();
}
if (quote && !str.contains(ORACLE_QUOTE_CHARACTER)) {
str = ORACLE_QUOTE_CHARACTER + str + ORACLE_QUOTE_CHARACTER;
}
return str;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
public class OracleMetadataHandlerTest
extends TestBase
{
private static final Schema PARTITION_SCHEMA = SchemaBuilder.newBuilder().addField("PARTITION_NAME", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build();
private static final Schema PARTITION_SCHEMA = SchemaBuilder.newBuilder().addField("partition_name", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build();
private DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", ORACLE_NAME,
"oracle://jdbc:oracle:thin:username/password@//127.0.0.1:1521/orcl");
private OracleMetadataHandler oracleMetadataHandler;
Expand Down Expand Up @@ -103,15 +103,15 @@ public void doGetTableLayout()
{
BlockAllocator blockAllocator = new BlockAllocatorImpl();
Constraints constraints = Mockito.mock(Constraints.class);
TableName tableName = new TableName("testSchema", "TESTTABLE");
TableName tableName = new TableName("testSchema", "\"TESTTABLE\"");
Schema partitionSchema = this.oracleMetadataHandler.getPartitionSchema("testCatalogName");
Set<String> partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet());
GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols);

PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class);
Mockito.when(this.connection.prepareStatement(OracleMetadataHandler.GET_PARTITIONS_QUERY)).thenReturn(preparedStatement);

String[] columns = {"PARTITION_NAME"};
String[] columns = {"PARTITION_NAME".toLowerCase()};
int[] types = {Types.VARCHAR};
Object[][] values = {{"p0"}, {"p1"}};
ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
Expand All @@ -127,7 +127,7 @@ public void doGetTableLayout()
for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
expectedValues.add(BlockUtils.rowToString(getTableLayoutResponse.getPartitions(), i));
}
Assert.assertEquals(expectedValues, Arrays.asList("[PARTITION_NAME : p0]", "[PARTITION_NAME : p1]"));
Assert.assertEquals(expectedValues, Arrays.asList("[partition_name : p0]", "[partition_name : p1]"));

SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
expectedSchemaBuilder.addField(FieldBuilder.newBuilder(OracleMetadataHandler.BLOCK_PARTITION_COLUMN_NAME, org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build());
Expand All @@ -144,15 +144,15 @@ public void doGetTableLayoutWithNoPartitions()
{
BlockAllocator blockAllocator = new BlockAllocatorImpl();
Constraints constraints = Mockito.mock(Constraints.class);
TableName tableName = new TableName("testSchema", "TESTTABLE");
TableName tableName = new TableName("testSchema", "\"TESTTABLE\"");
Schema partitionSchema = this.oracleMetadataHandler.getPartitionSchema("testCatalogName");
Set<String> partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet());
GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols);

PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class);
Mockito.when(this.connection.prepareStatement(OracleMetadataHandler.GET_PARTITIONS_QUERY)).thenReturn(preparedStatement);

String[] columns = {"PARTITION_NAME"};
String[] columns = {"PARTITION_NAME".toLowerCase()};
int[] types = {Types.VARCHAR};
Object[][] values = {{}};
ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
Expand All @@ -168,7 +168,7 @@ public void doGetTableLayoutWithNoPartitions()
for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
expectedValues.add(BlockUtils.rowToString(getTableLayoutResponse.getPartitions(), i));
}
Assert.assertEquals(expectedValues, Collections.singletonList("[PARTITION_NAME : 0]"));
Assert.assertEquals(expectedValues, Collections.singletonList("[partition_name : 0]"));

SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
expectedSchemaBuilder.addField(FieldBuilder.newBuilder(OracleMetadataHandler.BLOCK_PARTITION_COLUMN_NAME, org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build());
Expand Down Expand Up @@ -249,7 +249,7 @@ public void doGetSplitsContinuation()
PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class);
Mockito.when(this.connection.prepareStatement(OracleMetadataHandler.GET_PARTITIONS_QUERY)).thenReturn(preparedStatement);

String[] columns = {"PARTITION_NAME"};
String[] columns = {"PARTITION_NAME".toLowerCase()};
int[] types = {Types.VARCHAR};
Object[][] values = {{"p0"}, {"p1"}};
ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
Expand All @@ -265,7 +265,7 @@ public void doGetSplitsContinuation()
GetSplitsResponse getSplitsResponse = this.oracleMetadataHandler.doGetSplits(splitBlockAllocator, getSplitsRequest);

Set<Map<String, String>> expectedSplits = new HashSet<>();
expectedSplits.add(Collections.singletonMap("PARTITION_NAME", "p1"));
expectedSplits.add(Collections.singletonMap("PARTITION_NAME".toLowerCase(), "p1"));
Assert.assertEquals(expectedSplits.size(), getSplitsResponse.getSplits().size());
Set<Map<String, String>> actualSplits = getSplitsResponse.getSplits().stream().map(Split::getProperties).collect(Collectors.toSet());
Assert.assertEquals(expectedSplits, actualSplits);
Expand Down

0 comments on commit 14617b4

Please sign in to comment.