Skip to content

Commit

Permalink
Merge branch 'master' into sqlserver_table_casingissue
Browse files Browse the repository at this point in the history
  • Loading branch information
ejeffrli authored Jan 9, 2025
2 parents 4288128 + bb00b44 commit 47eb255
Show file tree
Hide file tree
Showing 8 changed files with 557 additions and 34 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/auto-approve.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ jobs:
env:
PR_URL: ${{github.event.pull_request.html_url}}
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
- name: Approve patch and minor updates
if: ${{ (steps.dependabot-metadata.outputs.update-type == 'version-update:semver-patch' || steps.dependabot-metadata.outputs.update-type == 'version-update:semver-minor') &&
!contains(steps.dependabot-metadata.outputs.new-version, 'preview') &&
!contains(steps.dependabot-metadata.outputs.new-version, 'alpha') &&
!contains(steps.dependabot-metadata.outputs.new-version, 'beta') &&
!contains(steps.dependabot-metadata.outputs.new-version, 'rc') }}
run: gh pr review $PR_URL --approve -b "I'm **approving** this pull request because **it includes a patch or minor update**"
env:
PR_URL: ${{github.event.pull_request.html_url}}
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
- name: Comment on major updates of non-development dependencies
if: ${{steps.dependabot-metadata.outputs.update-type == 'version-update:semver-major'}}
run: |
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/maven_push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ jobs:
- name: Identify any Maven Build changes
run: >
! (git status | grep "modified: " )
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Amazon Athena Query Federation

[![Build Status](https://github.com/awslabs/aws-athena-query-federation/workflows/Java%20CI%20Push/badge.svg)](https://github.com/awslabs/aws-athena-query-federation/actions)
[![codecov](https://codecov.io/github/awslabs/aws-athena-query-federation/graph/badge.svg?token=x5Q7jg0yUy)](https://codecov.io/github/awslabs/aws-athena-query-federation)

The Amazon Athena Query Federation SDK allows you to customize Amazon Athena with your own code. This enables you to integrate with new data sources, proprietary data formats, or build in new user defined functions. Initially these customizations will be limited to the parts of a query that occur during a TableScan operation but will eventually be expanded to include other parts of the query lifecycle using the same easy to understand interface.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public static List<TableName> getTableMetadata(PreparedStatement preparedStateme
}
}
catch (SQLException ex) {
LOGGER.info("Unable to return list of {} from data source!", tableType);
LOGGER.warn("Unable to return list of {} from data source!. Returning Empty list of table", tableType, ex);
}
return list.build();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@

/*-
* #%L
* athena-oracle
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.oracle;

import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;

/**
 * Static helpers that adjust the casing of Oracle schema and table names before they are used
 * in metadata queries, and that wrap names in Oracle identifier / string-literal quotes.
 * <p>
 * The casing behaviour is selected through the {@code casing_mode} config option, whose value is
 * matched case-insensitively against {@link OracleCasingMode}. When the option is not set, the
 * mode is {@code LOWER} if a default Glue connection is configured and {@code UPPER} otherwise.
 */
public class OracleCaseResolver
{
    private static final Logger LOGGER = LoggerFactory.getLogger(OracleCaseResolver.class);
    static final String SCHEMA_NAME_QUERY_TEMPLATE = "SELECT DISTINCT OWNER as \"OWNER\" FROM all_tables WHERE lower(OWNER) = ?";
    static final String TABLE_NAME_QUERY_TEMPLATE = "SELECT DISTINCT TABLE_NAME as \"TABLE_NAME\" FROM all_tables WHERE OWNER = ? and lower(TABLE_NAME) = ?";
    static final String SCHEMA_NAME_COLUMN_KEY = "OWNER";
    static final String TABLE_NAME_COLUMN_KEY = "TABLE_NAME";

    // the environment variable that can be set to specify which casing mode to use
    static final String CASING_MODE = "casing_mode";

    // used for identifying database objects (ex: table names)
    private static final String ORACLE_IDENTIFIER_CHARACTER = "\"";
    // used in SQL statements for character strings (ex: where OWNER = 'example')
    private static final String ORACLE_STRING_LITERAL_CHARACTER = "\'";

    private OracleCaseResolver() {}

    private enum OracleCasingMode
    {
        LOWER,      // casing mode to lower case everything (glue and trino lower case everything)
        UPPER,      // casing mode to upper case everything (oracle by default upper cases everything)
        CASE_INSENSITIVE_SEARCH     // casing mode to perform case insensitive search
    }

    /**
     * Returns a {@link TableName} whose schema and table parts follow the configured casing mode.
     *
     * @param connection open connection, only queried in {@code CASE_INSENSITIVE_SEARCH} mode
     * @param tableName the schema/table pair as received from the caller
     * @param configOptions connector configuration holding the casing mode
     * @return the adjusted table name
     * @throws SQLException declared for callers; lookup failures surface as {@link RuntimeException}
     * @throws IllegalArgumentException if the configured casing mode is not a known value
     */
    public static TableName getAdjustedTableObjectName(final Connection connection, TableName tableName, Map<String, String> configOptions)
            throws SQLException
    {
        OracleCasingMode casingMode = getCasingMode(configOptions);
        switch (casingMode) {
            case CASE_INSENSITIVE_SEARCH:
                // The schema must be resolved first: the table lookup filters on the exact OWNER value.
                String schemaNameCaseInsensitively = getSchemaNameCaseInsensitively(connection, tableName.getSchemaName());
                String tableNameCaseInsensitively = getTableNameCaseInsensitively(connection, schemaNameCaseInsensitively, tableName.getTableName());
                TableName tableNameResult = new TableName(schemaNameCaseInsensitively, tableNameCaseInsensitively);
                LOGGER.info("casing mode is `SEARCH`: performing case insensitive search for TableName object. TableName:{}", tableNameResult);
                return tableNameResult;
            case UPPER:
                // Locale.ROOT keeps identifier casing independent of the JVM default locale.
                TableName upperTableName = new TableName(tableName.getSchemaName().toUpperCase(Locale.ROOT), tableName.getTableName().toUpperCase(Locale.ROOT));
                LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for TableName object. TableName:{}", upperTableName);
                return upperTableName;
            case LOWER:
                TableName lowerTableName = new TableName(tableName.getSchemaName().toLowerCase(Locale.ROOT), tableName.getTableName().toLowerCase(Locale.ROOT));
                LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for TableName object. TableName:{}", lowerTableName);
                return lowerTableName;
        }
        LOGGER.warn("casing mode is empty: not adjust casing from input for TableName object. TableName:{}", tableName);
        return tableName;
    }

    /**
     * Returns the schema name adjusted to the configured casing mode.
     *
     * @param connection open connection, only queried in {@code CASE_INSENSITIVE_SEARCH} mode
     * @param schemaNameInput the schema name as received from the caller
     * @param configOptions connector configuration holding the casing mode
     * @return the adjusted schema name
     * @throws SQLException declared for callers; lookup failures surface as {@link RuntimeException}
     * @throws IllegalArgumentException if the configured casing mode is not a known value
     */
    public static String getAdjustedSchemaName(final Connection connection, String schemaNameInput, Map<String, String> configOptions)
            throws SQLException
    {
        OracleCasingMode casingMode = getCasingMode(configOptions);
        switch (casingMode) {
            case CASE_INSENSITIVE_SEARCH:
                LOGGER.info("casing mode is SEARCH: performing case insensitive search for Schema...");
                return getSchemaNameCaseInsensitively(connection, schemaNameInput);
            case UPPER:
                LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for Schema");
                return schemaNameInput.toUpperCase(Locale.ROOT);
            case LOWER:
                LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for Schema");
                return schemaNameInput.toLowerCase(Locale.ROOT);
        }

        return schemaNameInput;
    }

    /**
     * Looks up the stored casing of a schema (OWNER) by comparing lower-cased names in {@code all_tables}.
     *
     * @param connection open connection used to run the lookup
     * @param schemaName schema name in any casing
     * @return the schema name exactly as returned by the query
     * @throws RuntimeException if the query fails, or if it does not return exactly one row
     * @throws SQLException declared for callers; query failures are wrapped in {@link RuntimeException}
     */
    public static String getSchemaNameCaseInsensitively(final Connection connection, String schemaName)
            throws SQLException
    {
        String nameFromOracle = null;
        int matchCount = 0;
        try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
                .withConnection(connection)
                .withQuery(SCHEMA_NAME_QUERY_TEMPLATE)
                .withParameters(Arrays.asList(schemaName.toLowerCase(Locale.ROOT))).build();
                ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                matchCount++;
                String schemaNameCandidate = resultSet.getString(SCHEMA_NAME_COLUMN_KEY);
                LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", SCHEMA_NAME_COLUMN_KEY, schemaNameCandidate);
                nameFromOracle = schemaNameCandidate;
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(String.format("getSchemaNameCaseInsensitively query failed for %s", schemaName), e);
        }

        // Zero matches means the schema was not found; several mean the name is ambiguous by case.
        if (matchCount != 1) {
            throw new RuntimeException(String.format("Schema name case insensitive match failed, number of match : %d", matchCount));
        }

        return nameFromOracle;
    }

    /**
     * Looks up the stored casing of a table inside a schema by comparing lower-cased table names
     * in {@code all_tables}.
     *
     * @param connection open connection used to run the lookup
     * @param schemaName schema (OWNER) name, matched exactly, so it must already be in its stored casing
     * @param tableNameInput table name in any casing
     * @return the table name exactly as returned by the query
     * @throws RuntimeException if the query fails, or if it does not return exactly one row
     * @throws SQLException declared for callers; query failures are wrapped in {@link RuntimeException}
     */
    public static String getTableNameCaseInsensitively(final Connection connection, String schemaName, String tableNameInput)
            throws SQLException
    {
        // schema name input should be correct case before searching tableName already
        String nameFromOracle = null;
        int matchCount = 0;
        try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
                .withConnection(connection)
                .withQuery(TABLE_NAME_QUERY_TEMPLATE)
                .withParameters(Arrays.asList(schemaName, tableNameInput.toLowerCase(Locale.ROOT))).build();
                ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                matchCount++;
                String tableNameCandidate = resultSet.getString(TABLE_NAME_COLUMN_KEY);
                LOGGER.debug("Case insensitive search on columLabel: {}, table name: {}", TABLE_NAME_COLUMN_KEY, tableNameCandidate);
                nameFromOracle = tableNameCandidate;
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(String.format("getTableNameCaseInsensitively query failed for schema: %s tableName: %s", schemaName, tableNameInput), e);
        }

        // Zero matches means the table was not found; several mean the name is ambiguous by case.
        if (matchCount != 1) {
            throw new RuntimeException(String.format("Table name case insensitive match failed, number of match : %d", matchCount));
        }

        return nameFromOracle;
    }

    /**
     * Reads the casing mode from the config options.
     *
     * @param configOptions connector configuration
     * @return the configured mode, or the default ({@code LOWER} with a Glue connection,
     *         {@code UPPER} without) when no value is set
     * @throws IllegalArgumentException if a value is set but does not name a casing mode
     */
    private static OracleCasingMode getCasingMode(Map<String, String> configOptions)
    {
        boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION));
        String configuredMode = configOptions.get(CASING_MODE);
        // A missing key and a key mapped to null are both treated as "not set".
        if (configuredMode == null) {
            LOGGER.info("CASING MODE not set");
            return isGlueConnection ? OracleCasingMode.LOWER : OracleCasingMode.UPPER;
        }

        try {
            // Locale.ROOT: enum constant names are ASCII, so the lookup must not depend on the default locale.
            OracleCasingMode oracleCasingMode = OracleCasingMode.valueOf(configuredMode.toUpperCase(Locale.ROOT));
            LOGGER.info("CASING MODE enable: {}", oracleCasingMode);
            return oracleCasingMode;
        }
        catch (IllegalArgumentException ex) {
            // print error log for customer along with list of input
            LOGGER.error("Invalid input for:{}, input value:{}, valid values:{}", CASING_MODE, configuredMode, Arrays.asList(OracleCasingMode.values()), ex);
            throw ex;
        }
    }

    /**
     * Wraps the schema and table parts of a {@link TableName} in Oracle identifier quotes.
     * A part that already contains a double quote anywhere is left unchanged.
     *
     * @param inputTable the table name to quote
     * @return a new {@link TableName} with the quoted parts
     */
    public static TableName quoteTableName(TableName inputTable)
    {
        String schemaName = inputTable.getSchemaName();
        String tableName = inputTable.getTableName();
        if (!schemaName.contains(ORACLE_IDENTIFIER_CHARACTER)) {
            schemaName = ORACLE_IDENTIFIER_CHARACTER + schemaName + ORACLE_IDENTIFIER_CHARACTER;
        }
        if (!tableName.contains(ORACLE_IDENTIFIER_CHARACTER)) {
            tableName = ORACLE_IDENTIFIER_CHARACTER + tableName + ORACLE_IDENTIFIER_CHARACTER;
        }
        return new TableName(schemaName, tableName);
    }

    /**
     * Wraps a value in single quotes. A value that already contains a single quote anywhere is
     * returned unchanged.
     *
     * @param input the value to wrap
     * @return the value surrounded by single quotes, or the input itself if it contains one
     */
    public static String convertToLiteral(String input)
    {
        if (!input.contains(ORACLE_STRING_LITERAL_CHARACTER)) {
            input = ORACLE_STRING_LITERAL_CHARACTER + input + ORACLE_STRING_LITERAL_CHARACTER;
        }
        return input;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.athena.AthenaClient;
Expand All @@ -79,7 +78,6 @@
import java.util.Set;
import java.util.stream.Collectors;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.MODULUS_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.NULLIF_FUNCTION_NAME;
Expand All @@ -95,11 +93,9 @@ public class OracleMetadataHandler
static final String BLOCK_PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String ALL_PARTITIONS = "0";
static final String PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String CASING_MODE = "casing_mode";
private static final Logger LOGGER = LoggerFactory.getLogger(OracleMetadataHandler.class);
private static final int MAX_SPLITS_PER_REQUEST = 1000_000;
private static final String COLUMN_NAME = "COLUMN_NAME";
private static final String ORACLE_QUOTE_CHARACTER = "\"";

static final String LIST_PAGINATED_TABLES_QUERY = "SELECT TABLE_NAME as \"TABLE_NAME\", OWNER as \"TABLE_SCHEM\" FROM all_tables WHERE owner = ? ORDER BY TABLE_NAME OFFSET ? ROWS FETCH NEXT ? ROWS ONLY";

Expand Down Expand Up @@ -158,10 +154,11 @@ public Schema getPartitionSchema(final String catalogName)
public void getPartitions(final BlockWriter blockWriter, final GetTableLayoutRequest getTableLayoutRequest, QueryStatusChecker queryStatusChecker)
throws Exception
{
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), transformString(getTableLayoutRequest.getTableName().getSchemaName(), true),
transformString(getTableLayoutRequest.getTableName().getTableName(), true));
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
List<String> parameters = Arrays.asList(transformString(getTableLayoutRequest.getTableName().getTableName(), true));
TableName casedTableName = getTableLayoutRequest.getTableName();
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), casedTableName.getSchemaName(),
casedTableName.getTableName());
List<String> parameters = Arrays.asList(OracleCaseResolver.convertToLiteral(casedTableName.getTableName()));
try (PreparedStatement preparedStatement = new PreparedStatementBuilder().withConnection(connection).withQuery(GET_PARTITIONS_QUERY).withParameters(parameters).build();
ResultSet resultSet = preparedStatement.executeQuery()) {
// Return a single partition if no partitions defined
Expand Down Expand Up @@ -256,7 +253,8 @@ protected ListTablesResponse listPaginatedTables(final Connection connection, fi
int t = token != null ? Integer.parseInt(token) : 0;

LOGGER.info("Starting pagination at {} with page size {}", token, pageSize);
List<TableName> paginatedTables = getPaginatedTables(connection, listTablesRequest.getSchemaName(), t, pageSize);
String casedSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, listTablesRequest.getSchemaName(), configOptions);
List<TableName> paginatedTables = getPaginatedTables(connection, casedSchemaName, t, pageSize);
LOGGER.info("{} tables returned. Next token is {}", paginatedTables.size(), t + pageSize);
return new ListTablesResponse(listTablesRequest.getCatalogName(), paginatedTables, Integer.toString(t + pageSize));
}
Expand Down Expand Up @@ -310,7 +308,7 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge
{
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
Schema partitionSchema = getPartitionSchema(getTableRequest.getCatalogName());
TableName tableName = new TableName(transformString(getTableRequest.getTableName().getSchemaName(), false), transformString(getTableRequest.getTableName().getTableName(), false));
TableName tableName = OracleCaseResolver.getAdjustedTableObjectName(connection, getTableRequest.getTableName(), configOptions);
return new GetTableResponse(getTableRequest.getCatalogName(), tableName, getSchema(connection, tableName, partitionSchema),
partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()));
}
Expand Down Expand Up @@ -413,25 +411,4 @@ private Schema getSchema(Connection jdbcConnection, TableName tableName, Schema
return schemaBuilder.build();
}
}

/**
 * Applies the effective casing mode to a name and optionally wraps it in double quotes.
 * <p>
 * The mode is read from the {@code casing_mode} config option. When that option is absent it
 * defaults to "lower" if a Glue connection is configured and to "upper" otherwise. The name is
 * upper-cased only when the mode is "upper"; in any other mode it is returned in its input casing.
 *
 * @param str the name to transform
 * @param quote whether to surround the result with double quotes; no quotes are added when the
 *              value already contains one
 * @return the transformed name
 */
private String transformString(String str, boolean quote)
{
    final String fallbackMode = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION)) ? "lower" : "upper";
    final String effectiveMode = configOptions.getOrDefault(CASING_MODE, fallbackMode).toLowerCase();
    final String cased = "upper".equals(effectiveMode) ? str.toUpperCase() : str;

    // Nothing to wrap when quoting was not requested or a quote is already present.
    if (!quote || cased.contains(ORACLE_QUOTE_CHARACTER)) {
        return cased;
    }
    return ORACLE_QUOTE_CHARACTER + cased + ORACLE_QUOTE_CHARACTER;
}
}
Loading

0 comments on commit 47eb255

Please sign in to comment.