diff --git a/.github/workflows/auto-approve.yml b/.github/workflows/auto-approve.yml index d86d0abfb7..286f90933f 100644 --- a/.github/workflows/auto-approve.yml +++ b/.github/workflows/auto-approve.yml @@ -19,6 +19,16 @@ jobs: env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + - name: Approve patch and minor updates + if: ${{ (steps.dependabot-metadata.outputs.update-type == 'version-update:semver-patch' || steps.dependabot-metadata.outputs.update-type == 'version-update:semver-minor') && + !contains(steps.dependabot-metadata.outputs.new-version, 'preview') && + !contains(steps.dependabot-metadata.outputs.new-version, 'alpha') && + !contains(steps.dependabot-metadata.outputs.new-version, 'beta') && + !contains(steps.dependabot-metadata.outputs.new-version, 'rc') }} + run: gh pr review $PR_URL --approve -b "I'm **approving** this pull request because **it includes a patch or minor update**" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} - name: Comment on major updates of non-development dependencies if: ${{steps.dependabot-metadata.outputs.update-type == 'version-update:semver-major'}} run: | diff --git a/.github/workflows/maven_push.yml b/.github/workflows/maven_push.yml index 41d33d2699..297a8cc2c0 100644 --- a/.github/workflows/maven_push.yml +++ b/.github/workflows/maven_push.yml @@ -37,3 +37,12 @@ jobs: - name: Identify any Maven Build changes run: > ! (git status | grep "modified: " ) + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index 78be6fafca..52b1f1e09a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Amazon Athena Query Federation [![Build Status](https://github.com/awslabs/aws-athena-query-federation/workflows/Java%20CI%20Push/badge.svg)](https://github.com/awslabs/aws-athena-query-federation/actions) +[![codecov](https://codecov.io/github/awslabs/aws-athena-query-federation/graph/badge.svg?token=x5Q7jg0yUy)](https://codecov.io/github/awslabs/aws-athena-query-federation) The Amazon Athena Query Federation SDK allows you to customize Amazon Athena with your own code. This enables you to integrate with new data sources, proprietary data formats, or build in new user defined functions. Initially these customizations will be limited to the parts of a query that occur during a TableScan operation but will eventually be expanded to include other parts of the query lifecycle using the same easy to understand interface. diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JDBCUtil.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JDBCUtil.java index dbd9f4d421..71de4ce6f8 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JDBCUtil.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JDBCUtil.java @@ -212,7 +212,7 @@ public static List getTableMetadata(PreparedStatement preparedStateme } } catch (SQLException ex) { - LOGGER.info("Unable to return list of {} from data source!", tableType); + LOGGER.warn("Unable to return list of {} from data source!. Returning Empty list of table", tableType, ex); } return list.build(); } diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolver.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolver.java new file mode 100644 index 0000000000..ec3e767cf9 --- /dev/null +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolver.java @@ -0,0 +1,203 @@ + +/*- + * #%L + * athena-oracle + * %% + * Copyright (C) 2019 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.oracle; + +import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Map; + +import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION; + +public class OracleCaseResolver +{ + private static final Logger LOGGER = LoggerFactory.getLogger(OracleCaseResolver.class); + static final String SCHEMA_NAME_QUERY_TEMPLATE = "SELECT DISTINCT OWNER as \"OWNER\" FROM all_tables WHERE lower(OWNER) = ?"; + static final String TABLE_NAME_QUERY_TEMPLATE = "SELECT DISTINCT TABLE_NAME as \"TABLE_NAME\" FROM all_tables WHERE OWNER = ? and lower(TABLE_NAME) = ?"; + static final String SCHEMA_NAME_COLUMN_KEY = "OWNER"; + static final String TABLE_NAME_COLUMN_KEY = "TABLE_NAME"; + + // the environment variable that can be set to specify which casing mode to use + static final String CASING_MODE = "casing_mode"; + + // used for identifying database objects (ex: table names) + private static final String ORACLE_IDENTIFIER_CHARACTER = "\""; + // used in SQL statements for character strings (ex: where OWNER = 'example') + private static final String ORACLE_STRING_LITERAL_CHARACTER = "\'"; + + private OracleCaseResolver() {} + + private enum OracleCasingMode + { + LOWER, // casing mode to lower case everything (glue and trino lower case everything) + UPPER, // casing mode to upper case everything (oracle by default upper cases everything) + CASE_INSENSITIVE_SEARCH // casing mode to perform case insensitive search + } + + public static TableName getAdjustedTableObjectName(final Connection connection, TableName tableName, Map configOptions) + throws SQLException + { + OracleCasingMode casingMode = getCasingMode(configOptions); + switch (casingMode) { + case CASE_INSENSITIVE_SEARCH: + String schemaNameCaseInsensitively = getSchemaNameCaseInsensitively(connection, tableName.getSchemaName()); + String tableNameCaseInsensitively = getTableNameCaseInsensitively(connection, schemaNameCaseInsensitively, tableName.getTableName()); + TableName tableNameResult = new TableName(schemaNameCaseInsensitively, tableNameCaseInsensitively); + LOGGER.info("casing mode is `SEARCH`: performing case insensitive search for TableName object. TableName:{}", tableNameResult); + return tableNameResult; + case UPPER: + TableName upperTableName = new TableName(tableName.getSchemaName().toUpperCase(), tableName.getTableName().toUpperCase()); + LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for TableName object. TableName:{}", upperTableName); + return upperTableName; + case LOWER: + TableName lowerTableName = new TableName(tableName.getSchemaName().toLowerCase(), tableName.getTableName().toLowerCase()); + LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for TableName object. TableName:{}", lowerTableName); + return lowerTableName; + } + LOGGER.warn("casing mode is empty: not adjust casing from input for TableName object. TableName:{}", tableName); + return tableName; + } + + public static String getAdjustedSchemaName(final Connection connection, String schemaNameInput, Map configOptions) + throws SQLException + { + OracleCasingMode casingMode = getCasingMode(configOptions); + switch (casingMode) { + case CASE_INSENSITIVE_SEARCH: + LOGGER.info("casing mode is SEARCH: performing case insensitive search for Schema..."); + return getSchemaNameCaseInsensitively(connection, schemaNameInput); + case UPPER: + LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for Schema"); + return schemaNameInput.toUpperCase(); + case LOWER: + LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for Schema"); + return schemaNameInput.toLowerCase(); + } + + return schemaNameInput; + } + + public static String getSchemaNameCaseInsensitively(final Connection connection, String schemaName) + throws SQLException + { + String nameFromOracle = null; + int i = 0; + try (PreparedStatement preparedStatement = new PreparedStatementBuilder() + .withConnection(connection) + .withQuery(SCHEMA_NAME_QUERY_TEMPLATE) + .withParameters(Arrays.asList(schemaName.toLowerCase())).build(); + ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + i++; + String schemaNameCandidate = resultSet.getString(SCHEMA_NAME_COLUMN_KEY); + LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", SCHEMA_NAME_COLUMN_KEY, schemaNameCandidate); + nameFromOracle = schemaNameCandidate; + } + } + catch (SQLException e) { + throw new RuntimeException(String.format("getSchemaNameCaseInsensitively query failed for %s", schemaName), e); + } + + if (i != 1) { + throw new RuntimeException(String.format("Schema name case insensitive match failed, number of match : %d", i)); + } + + return nameFromOracle; + } + + public static String getTableNameCaseInsensitively(final Connection connection, String schemaName, String tableNameInput) + throws SQLException + { + // schema name input should be correct case before searching tableName already + String nameFromOracle = null; + int i = 0; + try (PreparedStatement preparedStatement = new PreparedStatementBuilder() + .withConnection(connection) + .withQuery(TABLE_NAME_QUERY_TEMPLATE) + .withParameters(Arrays.asList((schemaName), tableNameInput.toLowerCase())).build(); + ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + i++; + String schemaNameCandidate = resultSet.getString(TABLE_NAME_COLUMN_KEY); + LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", TABLE_NAME_COLUMN_KEY, schemaNameCandidate); + nameFromOracle = schemaNameCandidate; + } + } + catch (SQLException e) { + throw new RuntimeException(String.format("getTableNameCaseInsensitively query failed for schema: %s tableName: %s", schemaName, tableNameInput), e); + } + + if (i != 1) { + throw new RuntimeException(String.format("Schema name case insensitive match failed, number of match : %d", i)); + } + + return nameFromOracle; + } + + private static OracleCasingMode getCasingMode(Map configOptions) + { + boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION)); + if (!configOptions.containsKey(CASING_MODE)) { + LOGGER.info("CASING MODE not set"); + return isGlueConnection ? OracleCasingMode.LOWER : OracleCasingMode.UPPER; + } + + try { + OracleCasingMode oracleCasingMode = OracleCasingMode.valueOf(configOptions.get(CASING_MODE).toUpperCase()); + LOGGER.info("CASING MODE enable: {}", oracleCasingMode.toString()); + return oracleCasingMode; + } + catch (IllegalArgumentException ex) { + // print error log for customer along with list of input + LOGGER.error("Invalid input for:{}, input value:{}, valid values:{}", CASING_MODE, configOptions.get(CASING_MODE), Arrays.asList(OracleCasingMode.values()), ex); + throw ex; + } + } + + public static TableName quoteTableName(TableName inputTable) + { + String schemaName = inputTable.getSchemaName(); + String tableName = inputTable.getTableName(); + if (!schemaName.contains(ORACLE_IDENTIFIER_CHARACTER)) { + schemaName = ORACLE_IDENTIFIER_CHARACTER + schemaName + ORACLE_IDENTIFIER_CHARACTER; + } + if (!tableName.contains(ORACLE_IDENTIFIER_CHARACTER)) { + tableName = ORACLE_IDENTIFIER_CHARACTER + tableName + ORACLE_IDENTIFIER_CHARACTER; + } + return new TableName(schemaName, tableName); + } + + public static String convertToLiteral(String input) + { + if (!input.contains(ORACLE_STRING_LITERAL_CHARACTER)) { + input = ORACLE_STRING_LITERAL_CHARACTER + input + ORACLE_STRING_LITERAL_CHARACTER; + } + return input; + } +} diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java index 0b4c42eeee..f89b037f63 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandler.java @@ -62,7 +62,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; @@ -79,7 +78,6 @@ import java.util.Set; import java.util.stream.Collectors; -import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION; import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.MODULUS_FUNCTION_NAME; import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.NULLIF_FUNCTION_NAME; @@ -95,11 +93,9 @@ public class OracleMetadataHandler static final String BLOCK_PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase(); static final String ALL_PARTITIONS = "0"; static final String PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase(); - static final String CASING_MODE = "casing_mode"; private static final Logger LOGGER = LoggerFactory.getLogger(OracleMetadataHandler.class); private static final int MAX_SPLITS_PER_REQUEST = 1000_000; private static final String COLUMN_NAME = "COLUMN_NAME"; - private static final String ORACLE_QUOTE_CHARACTER = "\""; static final String LIST_PAGINATED_TABLES_QUERY = "SELECT TABLE_NAME as \"TABLE_NAME\", OWNER as \"TABLE_SCHEM\" FROM all_tables WHERE owner = ? ORDER BY TABLE_NAME OFFSET ? ROWS FETCH NEXT ? ROWS ONLY"; @@ -158,10 +154,11 @@ public Schema getPartitionSchema(final String catalogName) public void getPartitions(final BlockWriter blockWriter, final GetTableLayoutRequest getTableLayoutRequest, QueryStatusChecker queryStatusChecker) throws Exception { - LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), transformString(getTableLayoutRequest.getTableName().getSchemaName(), true), - transformString(getTableLayoutRequest.getTableName().getTableName(), true)); try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) { - List parameters = Arrays.asList(transformString(getTableLayoutRequest.getTableName().getTableName(), true)); + TableName casedTableName = getTableLayoutRequest.getTableName(); + LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), casedTableName.getSchemaName(), + casedTableName.getTableName()); + List parameters = Arrays.asList(OracleCaseResolver.convertToLiteral(casedTableName.getTableName())); try (PreparedStatement preparedStatement = new PreparedStatementBuilder().withConnection(connection).withQuery(GET_PARTITIONS_QUERY).withParameters(parameters).build(); ResultSet resultSet = preparedStatement.executeQuery()) { // Return a single partition if no partitions defined @@ -256,7 +253,8 @@ protected ListTablesResponse listPaginatedTables(final Connection connection, fi int t = token != null ? Integer.parseInt(token) : 0; LOGGER.info("Starting pagination at {} with page size {}", token, pageSize); - List paginatedTables = getPaginatedTables(connection, listTablesRequest.getSchemaName(), t, pageSize); + String casedSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, listTablesRequest.getSchemaName(), configOptions); + List paginatedTables = getPaginatedTables(connection, casedSchemaName, t, pageSize); LOGGER.info("{} tables returned. Next token is {}", paginatedTables.size(), t + pageSize); return new ListTablesResponse(listTablesRequest.getCatalogName(), paginatedTables, Integer.toString(t + pageSize)); } @@ -310,7 +308,7 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge { try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) { Schema partitionSchema = getPartitionSchema(getTableRequest.getCatalogName()); - TableName tableName = new TableName(transformString(getTableRequest.getTableName().getSchemaName(), false), transformString(getTableRequest.getTableName().getTableName(), false)); + TableName tableName = OracleCaseResolver.getAdjustedTableObjectName(connection, getTableRequest.getTableName(), configOptions); return new GetTableResponse(getTableRequest.getCatalogName(), tableName, getSchema(connection, tableName, partitionSchema), partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet())); } @@ -413,25 +411,4 @@ private Schema getSchema(Connection jdbcConnection, TableName tableName, Schema return schemaBuilder.build(); } } - - /** - * Always adds double quotes around the string - * If the lambda uses a glue connection, return the string as is (lowercased by the trino engine) - * Otherwise uppercase it (the default of oracle) - * @param str - * @param quote - * @return - */ - private String transformString(String str, boolean quote) - { - boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION)); - boolean uppercase = configOptions.getOrDefault(CASING_MODE, isGlueConnection ? "lower" : "upper").toLowerCase().equals("upper"); - if (uppercase) { - str = str.toUpperCase(); - } - if (quote && !str.contains(ORACLE_QUOTE_CHARACTER)) { - str = ORACLE_QUOTE_CHARACTER + str + ORACLE_QUOTE_CHARACTER; - } - return str; - } } diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolverTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolverTest.java new file mode 100644 index 0000000000..2b71810011 --- /dev/null +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleCaseResolverTest.java @@ -0,0 +1,323 @@ +/*- + * #%L + * athena-oracle + * %% + * Copyright (C) 2019 - 2022 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.oracle; + +import static org.mockito.ArgumentMatchers.nullable; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Types; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.amazonaws.athena.connectors.jdbc.TestBase; +import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; +import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; + +import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION; + +public class OracleCaseResolverTest + extends TestBase +{ + private JdbcConnectionFactory jdbcConnectionFactory; + private Connection connection; + + @Before + public void setup() + throws Exception + { + this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); + this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); + Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); + } + + @Test + public void getAdjustedTableObjectNameLower() + throws Exception + { + TableName inputTableName = new TableName("testschema", "testtable"); + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "lower"); + TableName outputTableName = OracleCaseResolver.getAdjustedTableObjectName(connection, inputTableName, config); + Assert.assertEquals(outputTableName.getSchemaName(), inputTableName.getSchemaName()); + Assert.assertEquals(outputTableName.getTableName(), inputTableName.getTableName()); + } + + @Test + public void getAdjustedTableObjectNameUpper() + throws Exception + { + TableName inputTableName = new TableName("testschema", "testtable"); + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "upper"); + TableName outputTableName = OracleCaseResolver.getAdjustedTableObjectName(connection, inputTableName, config); + Assert.assertEquals(outputTableName.getSchemaName(), inputTableName.getSchemaName().toUpperCase()); + Assert.assertEquals(outputTableName.getTableName(), inputTableName.getTableName().toUpperCase()); + } + + @Test + public void getAdjustedTableObjectNameSearch() + throws Exception + { + String inputSchemaName = "testschema"; + String matchedSchemaName = "testSchema"; + String inputTableName = "testtable"; + String matchedTableName = "TESTTABLE"; + TableName inputTableNameObject = new TableName(inputSchemaName, inputTableName); + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "case_insensitive_search"); + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.SCHEMA_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.TABLE_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columnsSchema = {OracleCaseResolver.SCHEMA_NAME_COLUMN_KEY}; + int[] typesSchema = {Types.VARCHAR}; + Object[][] valuesSchema = {{matchedSchemaName}}; + ResultSet resultSetSchema = mockResultSet(columnsSchema, typesSchema, valuesSchema, new AtomicInteger(-1)); + String[] columnsTable = {OracleCaseResolver.TABLE_NAME_COLUMN_KEY}; + int[] typesTable = {Types.VARCHAR}; + Object[][] valuesTable = {{matchedTableName}}; + ResultSet resultSetTable = mockResultSet(columnsTable, typesTable, valuesTable, new AtomicInteger(-1)); + // schema query is first, then table name + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSetSchema).thenReturn(resultSetTable); + + TableName outputTableName = OracleCaseResolver.getAdjustedTableObjectName(connection, inputTableNameObject, config); + Assert.assertEquals(outputTableName.getSchemaName(), matchedSchemaName); + Assert.assertEquals(outputTableName.getTableName(), matchedTableName); + } + + @Test + public void getAdjustedTableObjectNameNoConfig() + throws Exception + { + TableName inputTableName = new TableName("testschema", "testtable"); + Map config = Collections.emptyMap(); + TableName outputTableName = OracleCaseResolver.getAdjustedTableObjectName(connection, inputTableName, config); + Assert.assertEquals(outputTableName.getSchemaName(), inputTableName.getSchemaName().toUpperCase()); + Assert.assertEquals(outputTableName.getTableName(), inputTableName.getTableName().toUpperCase()); + } + + @Test + public void getAdjustedTableObjectNameGlueConnection() + throws Exception + { + TableName inputTableName = new TableName("testschema", "testtable"); + Map config = Collections.singletonMap(DEFAULT_GLUE_CONNECTION, "notBlank"); + TableName outputTableName = OracleCaseResolver.getAdjustedTableObjectName(connection, inputTableName, config); + Assert.assertEquals(outputTableName.getSchemaName(), inputTableName.getSchemaName()); + Assert.assertEquals(outputTableName.getTableName(), inputTableName.getTableName()); + } + + @Test + public void getAdjustedSchemaNameLower() + throws Exception + { + // the trino engine will lowercase anything + String inputSchemaName = "testschema"; + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "lower"); + String outputSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, inputSchemaName, config); + Assert.assertEquals(inputSchemaName, outputSchemaName); + } + + @Test + public void getAdjustedSchemaNameUpper() + throws Exception + { + String inputSchemaName = "testschema"; + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "upper"); + String outputSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, inputSchemaName, config); + Assert.assertEquals(inputSchemaName.toUpperCase(), outputSchemaName); + } + + @Test + public void getAdjustedSchemaNameSearch() + throws Exception + { + String inputSchemaName = "testschema"; + String matchedSchemaName = "testSchema"; + Map config = Collections.singletonMap(OracleCaseResolver.CASING_MODE, "case_insensitive_search"); + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.SCHEMA_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.SCHEMA_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {{matchedSchemaName}}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + String outputSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, inputSchemaName, config); + Assert.assertEquals(matchedSchemaName, outputSchemaName); + } + + @Test + public void getAdjustedSchemaNameNoConfig() + throws Exception + { + String inputSchemaName = "testschema"; + Map config = Collections.emptyMap(); + String outputSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, inputSchemaName, config); + Assert.assertEquals(inputSchemaName.toUpperCase(), outputSchemaName); + } + + @Test + public void getAdjustedSchemaNameGlueConnection() + throws Exception + { + // the trino engine will lowercase anything + String inputSchemaName = "testschema"; + Map config = Collections.singletonMap(DEFAULT_GLUE_CONNECTION, "notBlank"); + String outputSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, inputSchemaName, config); + Assert.assertEquals(inputSchemaName, outputSchemaName); + } + + @Test + public void getSchemaNameCaseInsensitively() + throws Exception + { + String inputSchemaName = "tEsTsChEmA"; + String matchedSchemaName = "testSchema"; + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.SCHEMA_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.SCHEMA_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {{matchedSchemaName}}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + String outputSchemaName = OracleCaseResolver.getSchemaNameCaseInsensitively(connection, inputSchemaName); + Assert.assertEquals(outputSchemaName, matchedSchemaName); + } + + @Test(expected = RuntimeException.class) + public void getSchemaNameCaseInsensitivelyFailMultipleMatches() + throws Exception + { + String inputSchemaName = "tEsTsChEmA"; + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.SCHEMA_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.SCHEMA_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {{"testSchema"}, {"TESTSCHEMA"}}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + // should throw exception because 2 matches found + OracleCaseResolver.getSchemaNameCaseInsensitively(connection, inputSchemaName); + } + + @Test(expected = RuntimeException.class) + public void getSchemaNameCaseInsensitivelyFailNoMatches() + throws Exception + { + String inputSchemaName = "tEsTsChEmA"; + String[] matchedSchemaName = {}; + + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.SCHEMA_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.SCHEMA_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {matchedSchemaName}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + // should throw exception because no matches found + OracleCaseResolver.getSchemaNameCaseInsensitively(connection, inputSchemaName); + } + + @Test + public void getTableNameCaseInsensitively() + throws Exception + { + String schemaName = "TestSchema"; + String inputTableName = "tEsTtAbLe"; + String matchedTableName = "TestTable"; + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.TABLE_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.TABLE_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {{matchedTableName}}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + String outputTableName = OracleCaseResolver.getTableNameCaseInsensitively(connection, schemaName, inputTableName); + Assert.assertEquals(outputTableName, matchedTableName); + } + + @Test(expected = RuntimeException.class) + public void getTableNameCaseInsensitivelyFailMultipleMatches() + throws Exception + { + String schemaName = "TestSchema"; + String inputTableName = "tEsTtAbLe"; + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.TABLE_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.TABLE_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {{"testtable"}, {"TESTTABLE"}}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + // should throw exception because 2 matches found + OracleCaseResolver.getTableNameCaseInsensitively(connection, schemaName, inputTableName); + } + + @Test(expected = RuntimeException.class) + public void getTableNameCaseInsensitivelyFailNoMatches() + throws Exception + { + String schemaName = "TestSchema"; + String inputTableName = "tEsTtAbLe"; + String[] matchedTableName = {}; + PreparedStatement preparedStatement = Mockito.mock(PreparedStatement.class); + Mockito.when(this.connection.prepareStatement(OracleCaseResolver.TABLE_NAME_QUERY_TEMPLATE)).thenReturn(preparedStatement); + + String[] columns = {OracleCaseResolver.TABLE_NAME_COLUMN_KEY}; + int[] types = {Types.VARCHAR}; + Object[][] values = {matchedTableName}; + ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1)); + Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet); + + // should throw exception because no matches found + OracleCaseResolver.getTableNameCaseInsensitively(connection, schemaName, inputTableName); + } + + @Test + public void convertToLiteral() + { + String input = "teststring"; + String expectedOutput = "\'teststring\'"; + Assert.assertEquals(expectedOutput, OracleCaseResolver.convertToLiteral(input)); + Assert.assertEquals(expectedOutput, OracleCaseResolver.convertToLiteral(expectedOutput)); + } +} diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java index 9a4cd4b376..80eace6dfa 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMetadataHandlerTest.java @@ -104,7 +104,7 @@ public void doGetTableLayout() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = Mockito.mock(Constraints.class); - TableName tableName = new TableName("testSchema", "\"TESTTABLE\""); + TableName tableName = new TableName("testSchema", "\'TESTTABLE\'"); Schema partitionSchema = this.oracleMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); @@ -145,7 +145,7 @@ public void doGetTableLayoutWithNoPartitions() { BlockAllocator blockAllocator = new BlockAllocatorImpl(); Constraints constraints = Mockito.mock(Constraints.class); - TableName tableName = new TableName("testSchema", "\"TESTTABLE\""); + TableName tableName = new TableName("testSchema", "\'TESTTABLE\'"); Schema partitionSchema = this.oracleMetadataHandler.getPartitionSchema("testCatalogName"); Set partitionCols = partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()); GetTableLayoutRequest getTableLayoutRequest = new GetTableLayoutRequest(this.federatedIdentity, "testQueryId", "testCatalogName", tableName, constraints, partitionSchema, partitionCols); @@ -327,7 +327,7 @@ public void doGetTable() PARTITION_SCHEMA.getFields().forEach(expectedSchemaBuilder::addField); Schema expected = expectedSchemaBuilder.build(); - TableName inputTableName = new TableName("TESTSCHEMA", "TESTTABLE"); + TableName inputTableName = OracleCaseResolver.quoteTableName(new TableName("TESTSCHEMA", "TESTTABLE")); Mockito.when(connection.getMetaData().getColumns("testCatalog", inputTableName.getSchemaName(), inputTableName.getTableName(), null)).thenReturn(resultSet); Mockito.when(connection.getCatalog()).thenReturn("testCatalog");