Commit

Fix error with multiple nested partition columns on Iceberg
jinyang_li committed Jan 4, 2025
1 parent 2ef6dc1 commit 5269716
Showing 2 changed files with 36 additions and 16 deletions.
Changed file 1 of 2:

@@ -14,6 +14,7 @@
 package io.trino.plugin.iceberg;

 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.hash.Hasher;
 import com.google.common.hash.Hashing;
 import io.trino.spi.connector.ConnectorPartitioningHandle;
@@ -22,7 +23,6 @@
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.types.Types;

-import java.util.ArrayDeque;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -64,44 +64,43 @@ private static Map<Integer, List<Integer>> buildDataPaths(PartitionSpec spec)
 {
     Set<Integer> partitionFieldIds = spec.fields().stream().map(PartitionField::sourceId).collect(toImmutableSet());

-    int channel = 0;
     Map<Integer, List<Integer>> fieldInfo = new HashMap<>();
     for (Types.NestedField field : spec.schema().asStruct().fields()) {
         // Partition fields can only be nested in a struct
         if (field.type() instanceof Types.StructType nestedStruct) {
-            if (buildDataPaths(partitionFieldIds, nestedStruct, new ArrayDeque<>(List.of(channel)), fieldInfo)) {
-                channel++;
-            }
+            buildDataPaths(partitionFieldIds, nestedStruct, fieldInfo);
         }
         else if (field.type().isPrimitiveType() && partitionFieldIds.contains(field.fieldId())) {
-            fieldInfo.put(field.fieldId(), ImmutableList.of(channel));
-            channel++;
+            fieldInfo.put(field.fieldId(), ImmutableList.of());
         }
     }
-    return fieldInfo;
+
+    // assign channel based on the order of the fields
+    List<Integer> keys = fieldInfo.keySet().stream().sorted().collect(toImmutableList());
+    ImmutableMap.Builder<Integer, List<Integer>> builder = ImmutableMap.builder();
+    for (int channel = 0; channel < keys.size(); channel++) {
+        int fieldId = keys.get(channel);
+        builder.put(fieldId, ImmutableList.<Integer>builder().add(channel).addAll(fieldInfo.get(fieldId)).build());
+    }
+    return builder.buildOrThrow();
 }

-private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, ArrayDeque<Integer> currentPaths, Map<Integer, List<Integer>> dataPaths)
+private static void buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, Map<Integer, List<Integer>> dataPaths)
 {
-    boolean hasPartitionFields = false;
     List<Types.NestedField> fields = struct.fields();
     for (int fieldOrdinal = 0; fieldOrdinal < fields.size(); fieldOrdinal++) {
         Types.NestedField field = fields.get(fieldOrdinal);
         int fieldId = field.fieldId();

-        currentPaths.addLast(fieldOrdinal);
         org.apache.iceberg.types.Type type = field.type();
         if (type instanceof Types.StructType nestedStruct) {
-            hasPartitionFields = buildDataPaths(partitionFieldIds, nestedStruct, currentPaths, dataPaths) || hasPartitionFields;
+            buildDataPaths(partitionFieldIds, nestedStruct, dataPaths);
         }
         // Map and List types are not supported in partitioning
         if (type.isPrimitiveType() && partitionFieldIds.contains(fieldId)) {
-            dataPaths.put(fieldId, ImmutableList.copyOf(currentPaths));
-            hasPartitionFields = true;
+            dataPaths.put(fieldId, ImmutableList.of(fieldOrdinal));
         }
-        currentPaths.removeLast();
     }
-    return hasPartitionFields;
 }

 public long getCacheKeyHint()
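To see why the change matters, here is a minimal, self-contained sketch of the new two-phase channel assignment (the field ids 3 and 7, the class name, and the variable values are made up for illustration): paths are collected during traversal with no channel attached, and channels are then assigned by sorted source field id, so the numbering no longer depends on how nested and top-level partition columns interleave during the walk.

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ChannelAssignmentSketch
{
    public static void main(String[] args)
    {
        // Phase 1 (done by buildDataPaths in the commit): field id -> path suffix,
        // empty for top-level columns, the ordinal within the struct for nested ones.
        Map<Integer, List<Integer>> fieldInfo = new HashMap<>();
        fieldInfo.put(7, ImmutableList.of());   // made-up top-level partition column
        fieldInfo.put(3, ImmutableList.of(0));  // made-up nested partition column at ordinal 0

        // Phase 2: assign channels by sorted field id and prepend them to the paths.
        List<Integer> keys = fieldInfo.keySet().stream().sorted().collect(ImmutableList.toImmutableList());
        ImmutableMap.Builder<Integer, List<Integer>> builder = ImmutableMap.builder();
        for (int channel = 0; channel < keys.size(); channel++) {
            int fieldId = keys.get(channel);
            builder.put(fieldId, ImmutableList.<Integer>builder().add(channel).addAll(fieldInfo.get(fieldId)).build());
        }
        System.out.println(builder.buildOrThrow()); // prints {3=[0, 0], 7=[1]}
    }
}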
Changed file 2 of 2:

@@ -48,6 +48,7 @@
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.testing.MaterializedResult.DEFAULT_PRECISION;
 import static io.trino.testing.MaterializedResult.resultBuilder;
+import static io.trino.testing.TestingNames.randomNameSuffix;
 import static java.util.Locale.ENGLISH;
 import static java.util.Objects.requireNonNull;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -683,4 +684,24 @@ private BaseTable loadTable(String tableName)
 {
     return IcebergTestUtils.loadTable(tableName, metastore, fileSystemFactory, "hive", "tpch");
 }
+
+@Test
+public void testPartitionColumns()
+{
+    String tableName = "test_partition_columns_" + randomNameSuffix();
+    assertUpdate(String.format("""
+            CREATE TABLE %s WITH (partitioning = ARRAY[
+                '"r1.f1"',
+                'bucket(b1, 4)'
+            ]) AS
+            SELECT
+                CAST('c1' AS VARCHAR) as c1
+                , CAST(ROW(1, 2) AS ROW(f1 integer, f2 integer)) as r1
+                , CAST('2022-01-01 01:01:01' AS TIMESTAMP) as d1
+                , CAST('2022-01-01 01:01:01' AS TIMESTAMP) as d2
+                , CAST('2022-01-01 01:01:01' AS TIMESTAMP) as d3
+                , CAST('2022-01-01 01:01:01' AS TIMESTAMP) as d4
+                , CAST('b' AS VARCHAR) as b1
+                , CAST('12345678' AS VARCHAR) as t1""", tableName), 1);
+}
 }
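The new test covers one nested partition column plus a bucketed top-level column. The commit title also names the multiple-nested-columns case; a hypothetical variant in the same style (not part of the commit; the test name and schema are illustrative) exercising two partition columns sourced from the same struct could look like this:

// Hypothetical variant (not in the commit): two partition columns drawn from
// fields of the same nested struct, the shape the channel re-assignment targets.
@Test
public void testMultipleNestedPartitionColumns()
{
    String tableName = "test_multiple_nested_partition_columns_" + randomNameSuffix();
    assertUpdate(String.format("""
            CREATE TABLE %s WITH (partitioning = ARRAY[
                '"r1.f1"',
                '"r1.f2"'
            ]) AS
            SELECT CAST(ROW(1, 2) AS ROW(f1 integer, f2 integer)) as r1""", tableName), 1);
}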
