diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBCaseInsensitiveResolver.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBCaseInsensitiveResolver.java new file mode 100644 index 0000000000..cc82c01736 --- /dev/null +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBCaseInsensitiveResolver.java @@ -0,0 +1,88 @@ +/*- + * #%L + * athena-mongodb + * %% + * Copyright (C) 2019 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.docdb; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoDatabase; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +public class DocDBCaseInsensitiveResolver +{ + private static final Logger logger = LoggerFactory.getLogger(DocDBCaseInsensitiveResolver.class); + + // This enable_case_insensitive match for schema and table name due to Athena lower case schema and table name. + // Capital letters are permitted for DocumentDB. 
+ private static final String ENABLE_CASE_INSENSITIVE_MATCH = "enable_case_insensitive_match"; + + private DocDBCaseInsensitiveResolver() {} + + public static String getSchemaNameCaseInsensitiveMatch(Map configOptions, MongoClient client, String unresolvedSchemaName) + { + String resolvedSchemaName = unresolvedSchemaName; + if (isCaseInsensitiveMatchEnable(configOptions)) { + logger.info("CaseInsensitiveMatch enable, SchemaName input: {}", resolvedSchemaName); + List candidateSchemaNames = StreamSupport.stream(client.listDatabaseNames().spliterator(), false) + .filter(candidateSchemaName -> candidateSchemaName.equalsIgnoreCase(unresolvedSchemaName)) + .collect(Collectors.toList()); + if (candidateSchemaNames.size() != 1) { + throw new IllegalArgumentException(String.format("Schema name is empty or more than 1 for case insensitive match. schemaName: %s, size: %d", + unresolvedSchemaName, candidateSchemaNames.size())); + } + resolvedSchemaName = candidateSchemaNames.get(0); + logger.info("CaseInsensitiveMatch, SchemaName resolved to: {}", resolvedSchemaName); + } + + return resolvedSchemaName; + } + + public static String getTableNameCaseInsensitiveMatch(Map configOptions, MongoDatabase mongoDatabase, String unresolvedTableName) + { + String resolvedTableName = unresolvedTableName; + if (isCaseInsensitiveMatchEnable(configOptions)) { + logger.info("CaseInsensitiveMatch enable, TableName input: {}", resolvedTableName); + List candidateTableNames = StreamSupport.stream(mongoDatabase.listCollectionNames().spliterator(), false) + .filter(candidateTableName -> candidateTableName.equalsIgnoreCase(unresolvedTableName)) + .collect(Collectors.toList()); + if (candidateTableNames.size() != 1) { + throw new IllegalArgumentException(String.format("Table name is empty or more than 1 for case insensitive match. 
tableName: %s, size: %d", + unresolvedTableName, candidateTableNames.size())); + } + resolvedTableName = candidateTableNames.get(0); + logger.info("CaseInsensitiveMatch, TableName resolved to: {}", resolvedTableName); + } + return resolvedTableName; + } + + private static boolean isCaseInsensitiveMatchEnable(Map configOptions) + { + String enableCaseInsensitiveMatchEnvValue = configOptions.getOrDefault(ENABLE_CASE_INSENSITIVE_MATCH, "false").toLowerCase(); + boolean enableCaseInsensitiveMatch = enableCaseInsensitiveMatchEnvValue.equals("true"); + logger.info("{} environment variable set to: {}. Resolved to: {}", + ENABLE_CASE_INSENSITIVE_MATCH, enableCaseInsensitiveMatchEnvValue, enableCaseInsensitiveMatch); + + return enableCaseInsensitiveMatch; + } +} diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java index 2e63346dd9..2026d6cf2e 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java @@ -45,6 +45,7 @@ import com.google.common.base.Strings; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCursor; +import com.mongodb.client.MongoDatabase; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -233,6 +234,9 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques throws Exception { logger.info("doGetTable: enter", request.getTableName()); + String schemaNameInput = request.getTableName().getSchemaName(); + String tableNameInput = request.getTableName().getTableName(); + TableName tableName = new TableName(schemaNameInput, tableNameInput); Schema schema = null; try { if (glue != null) { @@ -250,9 +254,14 @@ public GetTableResponse 
doGetTable(BlockAllocator blockAllocator, GetTableReques if (schema == null) { logger.info("doGetTable: Inferring schema for table[{}].", request.getTableName()); MongoClient client = getOrCreateConn(request); - schema = SchemaUtils.inferSchema(client, request.getTableName(), SCHEMA_INFERRENCE_NUM_DOCS); + //Attempt to update schema and table name with case insensitive match if enable + schemaNameInput = DocDBCaseInsensitiveResolver.getSchemaNameCaseInsensitiveMatch(configOptions, client, schemaNameInput); + MongoDatabase db = client.getDatabase(schemaNameInput); + tableNameInput = DocDBCaseInsensitiveResolver.getTableNameCaseInsensitiveMatch(configOptions, db, tableNameInput); + tableName = new TableName(schemaNameInput, tableNameInput); + schema = SchemaUtils.inferSchema(db, tableName, SCHEMA_INFERRENCE_NUM_DOCS); } - return new GetTableResponse(request.getCatalogName(), request.getTableName(), schema); + return new GetTableResponse(request.getCatalogName(), tableName, schema); } /** diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java index 40fe3fb5fe..731a7e1e87 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/SchemaUtils.java @@ -22,7 +22,6 @@ import com.amazonaws.athena.connector.lambda.data.FieldBuilder; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.domain.TableName; -import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; import org.apache.arrow.vector.types.Types; @@ -70,9 +69,8 @@ private SchemaUtils() {} * to use a reasonable default (like String) and coerce heterogeneous fields to avoid query failure but forcing * explicit handling by defining Schema in AWS Glue is likely a better approach. 
*/ - public static Schema inferSchema(MongoClient client, TableName table, int numObjToSample) + public static Schema inferSchema(MongoDatabase db, TableName table, int numObjToSample) { - MongoDatabase db = client.getDatabase(table.getSchemaName()); int docCount = 0; int fieldCount = 0; try (MongoCursor docs = db.getCollection(table.getTableName()).find().batchSize(numObjToSample) diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java index a72854fe50..dd16740da2 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandlerTest.java @@ -42,10 +42,12 @@ import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.glue.AWSGlue; import com.amazonaws.services.secretsmanager.AWSSecretsManager; +import com.google.common.collect.ImmutableList; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; +import com.mongodb.client.MongoIterable; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -76,6 +78,7 @@ import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -250,6 +253,175 @@ public void doGetTable() assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(unsupported.getType())); } + @Test + public void doGetTableCaseInsensitiveMatch() + throws Exception + { + + DocDBMetadataHandler caseInsensitiveHandler = new DocDBMetadataHandler(awsGlue, + 
connectionFactory, new LocalKeyFactory(), secretsManager, mockAthena, + "spillBucket", "spillPrefix", com.google.common.collect.ImmutableMap.of("enable_case_insensitive_match", "true")); + List documents = new ArrayList<>(); + + Document doc1 = new Document(); + documents.add(doc1); + doc1.put("stringCol", "stringVal"); + doc1.put("intCol", 1); + doc1.put("doubleCol", 2.2D); + doc1.put("longCol", 100L); + doc1.put("unsupported", new UnsupportedType()); + + Document doc2 = new Document(); + documents.add(doc2); + doc2.put("stringCol2", "stringVal"); + doc2.put("intCol2", 1); + doc2.put("doubleCol2", 2.2D); + doc2.put("longCol2", 100L); + + Document doc3 = new Document(); + documents.add(doc3); + doc3.put("stringCol", "stringVal"); + doc3.put("intCol2", 1); + doc3.put("doubleCol", 2.2D); + doc3.put("longCol2", 100L); + + MongoDatabase mockDatabase = mock(MongoDatabase.class); + MongoCollection mockCollection = mock(MongoCollection.class); + FindIterable mockIterable = mock(FindIterable.class); + + MongoIterable mockListDatabaseNamesIterable = mock(MongoIterable.class); + when(mockClient.listDatabaseNames()).thenReturn(mockListDatabaseNamesIterable); + + when(mockListDatabaseNamesIterable.spliterator()).thenReturn(ImmutableList.of(DEFAULT_SCHEMA).spliterator()); + + MongoIterable mockListCollectionsNamesIterable = mock(MongoIterable.class); + when(mockDatabase.listCollectionNames()).thenReturn(mockListCollectionsNamesIterable); + when(mockListCollectionsNamesIterable.spliterator()).thenReturn(ImmutableList.of(TEST_TABLE).spliterator()); + + when(mockClient.getDatabase(eq(DEFAULT_SCHEMA))).thenReturn(mockDatabase); + when(mockDatabase.getCollection(eq(TEST_TABLE))).thenReturn(mockCollection); + when(mockCollection.find()).thenReturn(mockIterable); + when(mockIterable.limit(anyInt())).thenReturn(mockIterable); + Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable); + when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); + 
when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator())); + + TableName tableNameInput = new TableName("DEfault", TEST_TABLE.toUpperCase()); + GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput); + GetTableResponse res = caseInsensitiveHandler.doGetTable(allocator, req); + + assertEquals(DEFAULT_SCHEMA, res.getTableName().getSchemaName()); + assertEquals(TEST_TABLE, res.getTableName().getTableName()); + logger.info("doGetTable - {}", res); + + assertEquals(9, res.getSchema().getFields().size()); + + Field stringCol = res.getSchema().findField("stringCol"); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(stringCol.getType())); + + Field stringCol2 = res.getSchema().findField("stringCol2"); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(stringCol2.getType())); + + Field intCol = res.getSchema().findField("intCol"); + assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(intCol.getType())); + + Field intCol2 = res.getSchema().findField("intCol2"); + assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(intCol2.getType())); + + Field doubleCol = res.getSchema().findField("doubleCol"); + assertEquals(Types.MinorType.FLOAT8, Types.getMinorTypeForArrowType(doubleCol.getType())); + + Field doubleCol2 = res.getSchema().findField("doubleCol2"); + assertEquals(Types.MinorType.FLOAT8, Types.getMinorTypeForArrowType(doubleCol2.getType())); + + Field longCol = res.getSchema().findField("longCol"); + assertEquals(Types.MinorType.BIGINT, Types.getMinorTypeForArrowType(longCol.getType())); + + Field longCol2 = res.getSchema().findField("longCol2"); + assertEquals(Types.MinorType.BIGINT, Types.getMinorTypeForArrowType(longCol2.getType())); + + Field unsupported = res.getSchema().findField("unsupported"); + assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(unsupported.getType())); + } + + + @Test + public void 
doGetTableCaseInsensitiveMatchMultipleMatch() + throws Exception + { + + DocDBMetadataHandler caseInsensitiveHandler = new DocDBMetadataHandler(awsGlue, + connectionFactory, new LocalKeyFactory(), secretsManager, mockAthena, + "spillBucket", "spillPrefix", com.google.common.collect.ImmutableMap.of("enable_case_insensitive_match", "true")); + + MongoIterable mockListDatabaseNamesIterable = mock(MongoIterable.class); + when(mockClient.listDatabaseNames()).thenReturn(mockListDatabaseNamesIterable); + when(mockListDatabaseNamesIterable.spliterator()).thenReturn(ImmutableList.of(DEFAULT_SCHEMA, DEFAULT_SCHEMA.toUpperCase()).spliterator()); + + TableName tableNameInput = new TableName("deFAULT", TEST_TABLE.toUpperCase()); + GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput); + try { + GetTableResponse res = caseInsensitiveHandler.doGetTable(allocator, req); + fail("doGetTableCaseInsensitiveMatchMultipleMatch should failed"); + } catch(IllegalArgumentException ex){ + assertEquals("Schema name is empty or more than 1 for case insensitive match. 
schemaName: deFAULT, size: 2", ex.getMessage()); + } + } + + @Test + public void doGetTableCaseInsensitiveMatchNotEnable() + throws Exception + { + + String mixedCaseSchemaName = "deFAULT"; + String mixedCaseTableName = "tesT_Table"; + List documents = new ArrayList<>(); + + Document doc1 = new Document(); + documents.add(doc1); + doc1.put("stringCol", "stringVal"); + doc1.put("intCol", 1); + doc1.put("doubleCol", 2.2D); + doc1.put("longCol", 100L); + doc1.put("unsupported", new UnsupportedType()); + + Document doc2 = new Document(); + documents.add(doc2); + doc2.put("stringCol2", "stringVal"); + doc2.put("intCol2", 1); + doc2.put("doubleCol2", 2.2D); + doc2.put("longCol2", 100L); + + Document doc3 = new Document(); + documents.add(doc3); + doc3.put("stringCol", "stringVal"); + doc3.put("intCol2", 1); + doc3.put("doubleCol", 2.2D); + doc3.put("longCol2", 100L); + + MongoDatabase mockDatabase = mock(MongoDatabase.class); + MongoCollection mockCollection = mock(MongoCollection.class); + FindIterable mockIterable = mock(FindIterable.class); + when(mockClient.getDatabase(eq(mixedCaseSchemaName))).thenReturn(mockDatabase); + when(mockDatabase.getCollection(eq(mixedCaseTableName))).thenReturn(mockCollection); + when(mockCollection.find()).thenReturn(mockIterable); + when(mockIterable.limit(anyInt())).thenReturn(mockIterable); + Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable); + when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); + when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator())); + + TableName tableNameInput = new TableName(mixedCaseSchemaName, mixedCaseTableName); + GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput); + GetTableResponse res = handler.doGetTable(allocator, req); + + assertEquals(mixedCaseSchemaName, res.getTableName().getSchemaName()); + assertEquals(mixedCaseTableName, res.getTableName().getTableName()); + + verify(mockClient, 
Mockito.never()).listDatabaseNames(); + verify(mockDatabase, Mockito.never()).listCollectionNames(); + + } + @Test public void doGetTableLayout() throws Exception diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java index 61bf31427d..a87af1e9b5 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/SchemaUtilsTest.java @@ -20,6 +20,7 @@ package com.amazonaws.athena.connectors.docdb; import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.google.common.collect.ImmutableMap; import com.mongodb.client.FindIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -52,11 +53,9 @@ public void UnsupportedTypeTest() unsupported.put("unsupported_col1", new UnsupportedType()); docs.add(unsupported); - MongoClient mockClient = mock(MongoClient.class); MongoDatabase mockDatabase = mock(MongoDatabase.class); MongoCollection mockCollection = mock(MongoCollection.class); FindIterable mockIterable = mock(FindIterable.class); - when(mockClient.getDatabase(any())).thenReturn(mockDatabase); when(mockDatabase.getCollection(any())).thenReturn(mockCollection); when(mockCollection.find()).thenReturn(mockIterable); when(mockIterable.limit(anyInt())).thenReturn(mockIterable); @@ -64,7 +63,7 @@ public void UnsupportedTypeTest() when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); when(mockIterable.iterator()).thenReturn(new StubbingCursor(docs.iterator())); - Schema schema = SchemaUtils.inferSchema(mockClient, new TableName("test", "test"), 10); + Schema schema = SchemaUtils.inferSchema(mockDatabase, new TableName("test", "test"), 10); assertEquals(1, schema.getFields().size()); Map fields = new HashMap<>(); @@ -115,11 +114,9 @@ public void basicMergeTest() doc3.put("col5", 
list); docs.add(doc3); - MongoClient mockClient = mock(MongoClient.class); MongoDatabase mockDatabase = mock(MongoDatabase.class); MongoCollection mockCollection = mock(MongoCollection.class); FindIterable mockIterable = mock(FindIterable.class); - when(mockClient.getDatabase(any())).thenReturn(mockDatabase); when(mockDatabase.getCollection(any())).thenReturn(mockCollection); when(mockCollection.find()).thenReturn(mockIterable); when(mockIterable.limit(anyInt())).thenReturn(mockIterable); @@ -127,7 +124,7 @@ public void basicMergeTest() when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); when(mockIterable.iterator()).thenReturn(new StubbingCursor(docs.iterator())); - Schema schema = SchemaUtils.inferSchema(mockClient, new TableName("test", "test"), 10); + Schema schema = SchemaUtils.inferSchema(mockDatabase, new TableName("test", "test"), 10); assertEquals(6, schema.getFields().size()); Map fields = new HashMap<>(); @@ -167,11 +164,9 @@ public void emptyListTest() doc2.put("col4", list2); docs.add(doc2); - MongoClient mockClient = mock(MongoClient.class); MongoDatabase mockDatabase = mock(MongoDatabase.class); MongoCollection mockCollection = mock(MongoCollection.class); FindIterable mockIterable = mock(FindIterable.class); - when(mockClient.getDatabase(any())).thenReturn(mockDatabase); when(mockDatabase.getCollection(any())).thenReturn(mockCollection); when(mockCollection.find()).thenReturn(mockIterable); when(mockIterable.limit(anyInt())).thenReturn(mockIterable); @@ -179,7 +174,7 @@ public void emptyListTest() when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable); when(mockIterable.iterator()).thenReturn(new StubbingCursor(docs.iterator())); - Schema schema = SchemaUtils.inferSchema(mockClient, new TableName("test", "test"), 10); + Schema schema = SchemaUtils.inferSchema(mockDatabase, new TableName("test", "test"), 10); assertEquals(4, schema.getFields().size()); Map fields = new HashMap<>();