Skip to content

Commit

Permalink
Enable case insensitive match for docdb on schema and table names (#1639
Browse files Browse the repository at this point in the history
)

Co-authored-by: ejeffrli <[email protected]>
  • Loading branch information
chngpe and ejeffrli authored Dec 7, 2023
1 parent b96aabf commit 99c3679
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*-
* #%L
* athena-mongodb
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.docdb;

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoDatabase;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

public class DocDBCaseInsensitiveResolver
{
private static final Logger logger = LoggerFactory.getLogger(DocDBCaseInsensitiveResolver.class);

// This enable_case_insensitive match for schema and table name due to Athena lower case schema and table name.
// Capital letters are permitted for DocumentDB.
private static final String ENABLE_CASE_INSENSITIVE_MATCH = "enable_case_insensitive_match";

private DocDBCaseInsensitiveResolver() {}

public static String getSchemaNameCaseInsensitiveMatch(Map<String, String> configOptions, MongoClient client, String unresolvedSchemaName)
{
String resolvedSchemaName = unresolvedSchemaName;
if (isCaseInsensitiveMatchEnable(configOptions)) {
logger.info("CaseInsensitiveMatch enable, SchemaName input: {}", resolvedSchemaName);
List<String> candidateSchemaNames = StreamSupport.stream(client.listDatabaseNames().spliterator(), false)
.filter(candidateSchemaName -> candidateSchemaName.equalsIgnoreCase(unresolvedSchemaName))
.collect(Collectors.toList());
if (candidateSchemaNames.size() != 1) {
throw new IllegalArgumentException(String.format("Schema name is empty or more than 1 for case insensitive match. schemaName: %s, size: %d",
unresolvedSchemaName, candidateSchemaNames.size()));
}
resolvedSchemaName = candidateSchemaNames.get(0);
logger.info("CaseInsensitiveMatch, SchemaName resolved to: {}", resolvedSchemaName);
}

return resolvedSchemaName;
}

public static String getTableNameCaseInsensitiveMatch(Map<String, String> configOptions, MongoDatabase mongoDatabase, String unresolvedTableName)
{
String resolvedTableName = unresolvedTableName;
if (isCaseInsensitiveMatchEnable(configOptions)) {
logger.info("CaseInsensitiveMatch enable, TableName input: {}", resolvedTableName);
List<String> candidateTableNames = StreamSupport.stream(mongoDatabase.listCollectionNames().spliterator(), false)
.filter(candidateTableName -> candidateTableName.equalsIgnoreCase(unresolvedTableName))
.collect(Collectors.toList());
if (candidateTableNames.size() != 1) {
throw new IllegalArgumentException(String.format("Table name is empty or more than 1 for case insensitive match. schemaName: %s, size: %d",
unresolvedTableName, candidateTableNames.size()));
}
resolvedTableName = candidateTableNames.get(0);
logger.info("CaseInsensitiveMatch, TableName resolved to: {}", resolvedTableName);
}
return resolvedTableName;
}

private static boolean isCaseInsensitiveMatchEnable(Map<String, String> configOptions)
{
String enableCaseInsensitiveMatchEnvValue = configOptions.getOrDefault(ENABLE_CASE_INSENSITIVE_MATCH, "false").toLowerCase();
boolean enableCaseInsensitiveMatch = enableCaseInsensitiveMatchEnvValue.equals("true");
logger.info("{} environment variable set to: {}. Resolved to: {}",
ENABLE_CASE_INSENSITIVE_MATCH, enableCaseInsensitiveMatchEnvValue, enableCaseInsensitiveMatch);

return enableCaseInsensitiveMatch;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import com.google.common.base.Strings;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand Down Expand Up @@ -233,6 +234,9 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
throws Exception
{
logger.info("doGetTable: enter", request.getTableName());
String schemaNameInput = request.getTableName().getSchemaName();
String tableNameInput = request.getTableName().getTableName();
TableName tableName = new TableName(schemaNameInput, tableNameInput);
Schema schema = null;
try {
if (glue != null) {
Expand All @@ -250,9 +254,14 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
if (schema == null) {
logger.info("doGetTable: Inferring schema for table[{}].", request.getTableName());
MongoClient client = getOrCreateConn(request);
schema = SchemaUtils.inferSchema(client, request.getTableName(), SCHEMA_INFERRENCE_NUM_DOCS);
//Attempt to update schema and table name with case insensitive match if enable
schemaNameInput = DocDBCaseInsensitiveResolver.getSchemaNameCaseInsensitiveMatch(configOptions, client, schemaNameInput);
MongoDatabase db = client.getDatabase(schemaNameInput);
tableNameInput = DocDBCaseInsensitiveResolver.getTableNameCaseInsensitiveMatch(configOptions, db, tableNameInput);
tableName = new TableName(schemaNameInput, tableNameInput);
schema = SchemaUtils.inferSchema(db, tableName, SCHEMA_INFERRENCE_NUM_DOCS);
}
return new GetTableResponse(request.getCatalogName(), request.getTableName(), schema);
return new GetTableResponse(request.getCatalogName(), tableName, schema);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import org.apache.arrow.vector.types.Types;
Expand Down Expand Up @@ -70,9 +69,8 @@ private SchemaUtils() {}
* to use a reasonable default (like String) and coerce heterogeneous fields to avoid query failure but forcing
* explicit handling by defining Schema in AWS Glue is likely a better approach.
*/
public static Schema inferSchema(MongoClient client, TableName table, int numObjToSample)
public static Schema inferSchema(MongoDatabase db, TableName table, int numObjToSample)
{
MongoDatabase db = client.getDatabase(table.getSchemaName());
int docCount = 0;
int fieldCount = 0;
try (MongoCursor<Document> docs = db.getCollection(table.getTableName()).find().batchSize(numObjToSample)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.glue.AWSGlue;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableList;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.MongoIterable;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand Down Expand Up @@ -76,6 +78,7 @@
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
Expand Down Expand Up @@ -250,6 +253,175 @@ public void doGetTable()
assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(unsupported.getType()));
}

@Test
public void doGetTableCaseInsensitiveMatch()
throws Exception
{

DocDBMetadataHandler caseInsensitiveHandler = new DocDBMetadataHandler(awsGlue,
connectionFactory, new LocalKeyFactory(), secretsManager, mockAthena,
"spillBucket", "spillPrefix", com.google.common.collect.ImmutableMap.of("enable_case_insensitive_match", "true"));
List<Document> documents = new ArrayList<>();

Document doc1 = new Document();
documents.add(doc1);
doc1.put("stringCol", "stringVal");
doc1.put("intCol", 1);
doc1.put("doubleCol", 2.2D);
doc1.put("longCol", 100L);
doc1.put("unsupported", new UnsupportedType());

Document doc2 = new Document();
documents.add(doc2);
doc2.put("stringCol2", "stringVal");
doc2.put("intCol2", 1);
doc2.put("doubleCol2", 2.2D);
doc2.put("longCol2", 100L);

Document doc3 = new Document();
documents.add(doc3);
doc3.put("stringCol", "stringVal");
doc3.put("intCol2", 1);
doc3.put("doubleCol", 2.2D);
doc3.put("longCol2", 100L);

MongoDatabase mockDatabase = mock(MongoDatabase.class);
MongoCollection mockCollection = mock(MongoCollection.class);
FindIterable mockIterable = mock(FindIterable.class);

MongoIterable mockListDatabaseNamesIterable = mock(MongoIterable.class);
when(mockClient.listDatabaseNames()).thenReturn(mockListDatabaseNamesIterable);

when(mockListDatabaseNamesIterable.spliterator()).thenReturn(ImmutableList.of(DEFAULT_SCHEMA).spliterator());

MongoIterable mockListCollectionsNamesIterable = mock(MongoIterable.class);
when(mockDatabase.listCollectionNames()).thenReturn(mockListCollectionsNamesIterable);
when(mockListCollectionsNamesIterable.spliterator()).thenReturn(ImmutableList.of(TEST_TABLE).spliterator());

when(mockClient.getDatabase(eq(DEFAULT_SCHEMA))).thenReturn(mockDatabase);
when(mockDatabase.getCollection(eq(TEST_TABLE))).thenReturn(mockCollection);
when(mockCollection.find()).thenReturn(mockIterable);
when(mockIterable.limit(anyInt())).thenReturn(mockIterable);
Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable);
when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable);
when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator()));

TableName tableNameInput = new TableName("DEfault", TEST_TABLE.toUpperCase());
GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput);
GetTableResponse res = caseInsensitiveHandler.doGetTable(allocator, req);

assertEquals(DEFAULT_SCHEMA, res.getTableName().getSchemaName());
assertEquals(TEST_TABLE, res.getTableName().getTableName());
logger.info("doGetTable - {}", res);

assertEquals(9, res.getSchema().getFields().size());

Field stringCol = res.getSchema().findField("stringCol");
assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(stringCol.getType()));

Field stringCol2 = res.getSchema().findField("stringCol2");
assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(stringCol2.getType()));

Field intCol = res.getSchema().findField("intCol");
assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(intCol.getType()));

Field intCol2 = res.getSchema().findField("intCol2");
assertEquals(Types.MinorType.INT, Types.getMinorTypeForArrowType(intCol2.getType()));

Field doubleCol = res.getSchema().findField("doubleCol");
assertEquals(Types.MinorType.FLOAT8, Types.getMinorTypeForArrowType(doubleCol.getType()));

Field doubleCol2 = res.getSchema().findField("doubleCol2");
assertEquals(Types.MinorType.FLOAT8, Types.getMinorTypeForArrowType(doubleCol2.getType()));

Field longCol = res.getSchema().findField("longCol");
assertEquals(Types.MinorType.BIGINT, Types.getMinorTypeForArrowType(longCol.getType()));

Field longCol2 = res.getSchema().findField("longCol2");
assertEquals(Types.MinorType.BIGINT, Types.getMinorTypeForArrowType(longCol2.getType()));

Field unsupported = res.getSchema().findField("unsupported");
assertEquals(Types.MinorType.VARCHAR, Types.getMinorTypeForArrowType(unsupported.getType()));
}


@Test
public void doGetTableCaseInsensitiveMatchMultipleMatch()
throws Exception
{

DocDBMetadataHandler caseInsensitiveHandler = new DocDBMetadataHandler(awsGlue,
connectionFactory, new LocalKeyFactory(), secretsManager, mockAthena,
"spillBucket", "spillPrefix", com.google.common.collect.ImmutableMap.of("enable_case_insensitive_match", "true"));

MongoIterable mockListDatabaseNamesIterable = mock(MongoIterable.class);
when(mockClient.listDatabaseNames()).thenReturn(mockListDatabaseNamesIterable);
when(mockListDatabaseNamesIterable.spliterator()).thenReturn(ImmutableList.of(DEFAULT_SCHEMA, DEFAULT_SCHEMA.toUpperCase()).spliterator());

TableName tableNameInput = new TableName("deFAULT", TEST_TABLE.toUpperCase());
GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput);
try {
GetTableResponse res = caseInsensitiveHandler.doGetTable(allocator, req);
fail("doGetTableCaseInsensitiveMatchMultipleMatch should failed");
} catch(IllegalArgumentException ex){
assertEquals("Schema name is empty or more than 1 for case insensitive match. schemaName: deFAULT, size: 2", ex.getMessage());
}
}

@Test
public void doGetTableCaseInsensitiveMatchNotEnable()
throws Exception
{

String mixedCaseSchemaName = "deFAULT";
String mixedCaseTableName = "tesT_Table";
List<Document> documents = new ArrayList<>();

Document doc1 = new Document();
documents.add(doc1);
doc1.put("stringCol", "stringVal");
doc1.put("intCol", 1);
doc1.put("doubleCol", 2.2D);
doc1.put("longCol", 100L);
doc1.put("unsupported", new UnsupportedType());

Document doc2 = new Document();
documents.add(doc2);
doc2.put("stringCol2", "stringVal");
doc2.put("intCol2", 1);
doc2.put("doubleCol2", 2.2D);
doc2.put("longCol2", 100L);

Document doc3 = new Document();
documents.add(doc3);
doc3.put("stringCol", "stringVal");
doc3.put("intCol2", 1);
doc3.put("doubleCol", 2.2D);
doc3.put("longCol2", 100L);

MongoDatabase mockDatabase = mock(MongoDatabase.class);
MongoCollection mockCollection = mock(MongoCollection.class);
FindIterable mockIterable = mock(FindIterable.class);
when(mockClient.getDatabase(eq(mixedCaseSchemaName))).thenReturn(mockDatabase);
when(mockDatabase.getCollection(eq(mixedCaseTableName))).thenReturn(mockCollection);
when(mockCollection.find()).thenReturn(mockIterable);
when(mockIterable.limit(anyInt())).thenReturn(mockIterable);
Mockito.lenient().when(mockIterable.maxScan(anyInt())).thenReturn(mockIterable);
when(mockIterable.batchSize(anyInt())).thenReturn(mockIterable);
when(mockIterable.iterator()).thenReturn(new StubbingCursor(documents.iterator()));

TableName tableNameInput = new TableName(mixedCaseSchemaName, mixedCaseTableName);
GetTableRequest req = new GetTableRequest(IDENTITY, QUERY_ID, DEFAULT_CATALOG, tableNameInput);
GetTableResponse res = handler.doGetTable(allocator, req);

assertEquals(mixedCaseSchemaName, res.getTableName().getSchemaName());
assertEquals(mixedCaseTableName, res.getTableName().getTableName());

verify(mockClient, Mockito.never()).listDatabaseNames();
verify(mockDatabase, Mockito.never()).listCollectionNames();

}

@Test
public void doGetTableLayout()
throws Exception
Expand Down
Loading

0 comments on commit 99c3679

Please sign in to comment.