From 5a0ab2a0abdfa729fd0c8f390071ebf0551ddb03 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Wed, 13 Nov 2024 13:56:54 +0100 Subject: [PATCH] Remodel CastingBehavior type --- .../sql/delta/PreprocessTableMerge.scala | 2 +- .../sql/delta/UpdateExpressionsSupport.scala | 68 ++++++++++++------- .../delta/commands/MergeIntoCommandBase.scala | 2 +- .../spark/sql/delta/sources/DeltaSink.scala | 6 +- .../DeltaInsertIntoColumnOrderSuite.scala | 21 +++--- 5 files changed, 60 insertions(+), 39 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala index af5a035e58..2dc43ccdb6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala @@ -185,7 +185,7 @@ case class PreprocessTableMerge(override val conf: SQLConf) castIfNeeded( a.expr, targetAttrib.dataType, - castingBehavior = CastingBehavior.forMergeOrUpdate(withSchemaEvolution), + castingBehavior = MergeOrUpdateCastingBehavior(withSchemaEvolution), targetAttrib.name), targetColNameResolved = true) }.getOrElse { diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala index 68c58af010..3e23cb25ec 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala @@ -46,30 +46,49 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper with De case class UpdateOperation(targetColNameParts: Seq[String], updateExpr: Expression) /** + * The following trait and classes define casting behaviors to use in `castIfNeeded()`. + * @param resolveStructsByName Whether struct fields should be resolved by name or by position + * during struct cast. * @param allowMissingStructField Whether missing struct fields are allowed in the data to cast. + * Only relevant when struct fields are resolved by name. * When true, missing struct fields in the input are set to null. * When false, an error is thrown. * Note: this should be set to true for schema evolution to work as * the target schema may typically contain new struct fields not * present in the input. - * @param resolveStructsByName Whether struct fields should be resolved by name or by position - * during struct cast. - * @param isMergeOrUpdate Allow differentiating between merge/update and insert operations - * in error messages and to provide backward compatible behavior. */ - case class CastingBehavior( - allowMissingStructField: Boolean, - resolveStructsByName: Boolean, - isMergeOrUpdate: Boolean - ) - - object CastingBehavior { - def forMergeOrUpdate(allowSchemaEvolution: Boolean): CastingBehavior = - CastingBehavior( - allowMissingStructField = allowSchemaEvolution, - resolveStructsByName = - conf.getConf(DeltaSQLConf.DELTA_RESOLVE_MERGE_UPDATE_STRUCTS_BY_NAME), - isMergeOrUpdate = true) + sealed trait CastingBehavior { + val resolveStructsByName: Boolean + val allowMissingStructField: Boolean + } + + case class CastByPosition() extends CastingBehavior { + val resolveStructsByName: Boolean = false + val allowMissingStructField: Boolean = false + } + + case class CastByName(allowMissingStructField: Boolean) extends CastingBehavior { + val resolveStructsByName: Boolean = true + } + + /* + * MERGE and UPDATE casting behavior is configurable using internal configs to allow reverting to + * legacy behavior. In particular: + * - 'resolveMergeUpdateStructsByName.enabled': defaults to resolution by name for struct fields, + * can be disabled to revert to resolution by position. + * - 'updateAndMergeCastingFollowsAnsiEnabledFlag': defaults to following + * 'spark.sql.storeAssignmentPolicy' for the type of cast to use, can be enabled to revert to + * following 'spark.sql.ansi.enabled'. See `cast()` below. + */ + trait MergeOrUpdateCastingBehavior + object MergeOrUpdateCastingBehavior { + def apply(schemaEvolutionEnabled: Boolean): CastingBehavior = + if (conf.getConf(DeltaSQLConf.DELTA_RESOLVE_MERGE_UPDATE_STRUCTS_BY_NAME)) { + new CastByName(allowMissingStructField = schemaEvolutionEnabled) + with MergeOrUpdateCastingBehavior + } else { + new CastByPosition() with MergeOrUpdateCastingBehavior + } } /** @@ -334,7 +353,7 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper with De Some(castIfNeeded( fullyMatchedOp.get.updateExpr, targetCol.dataType, - castingBehavior = CastingBehavior.forMergeOrUpdate(allowSchemaEvolution), + castingBehavior = MergeOrUpdateCastingBehavior(allowSchemaEvolution), targetCol.name)) } else { // So there are prefix-matched update operations, but none of them is a full match. Then @@ -477,7 +496,7 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper with De dataType: DataType, castingBehavior: CastingBehavior, columnName: String): Expression = { - if (castingBehavior.isMergeOrUpdate && + if (castingBehavior.isInstanceOf[MergeOrUpdateCastingBehavior] && conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) { return Cast(child, dataType, Option(conf.sessionLocalTimeZone)) } @@ -488,11 +507,12 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper with De case SQLConf.StoreAssignmentPolicy.ANSI => val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true) if (canCauseCastOverflow(cast)) { - if (castingBehavior.isMergeOrUpdate) { - CheckOverflowInTableWrite(cast, columnName) - } else { - cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) - CheckOverflowInTableInsert(cast, columnName) + castingBehavior match { + case _: MergeOrUpdateCastingBehavior => + CheckOverflowInTableWrite(cast, columnName) + case _ => + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + CheckOverflowInTableInsert(cast, columnName) } } else { cast diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala index a67eeb08d8..b2ccdf715e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/MergeIntoCommandBase.scala @@ -119,7 +119,7 @@ trait MergeIntoCommandBase extends LeafRunnableCommand castIfNeeded( attr.withNullability(attr.nullable || makeNullable), col.dataType, - castingBehavior = CastingBehavior.forMergeOrUpdate(canMergeSchema), + castingBehavior = MergeOrUpdateCastingBehavior(canMergeSchema), col.name), col.name )() diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSink.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSink.scala index ef664d544b..4ad651abd0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSink.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSink.scala @@ -202,11 +202,7 @@ case class DeltaSink( val castExpr = castIfNeeded( fromExpression = data.col(columnName).expr, dataType = targetTypes(columnName), - castingBehavior = CastingBehavior( - allowMissingStructField = true, - resolveStructsByName = true, - isMergeOrUpdate = false - ), + castingBehavior = CastByName(allowMissingStructField = true), columnName = columnName ) Column(Alias(castExpr, columnName)()) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoColumnOrderSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoColumnOrderSuite.scala index 08cf18088a..15c8249df2 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoColumnOrderSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoColumnOrderSuite.scala @@ -33,6 +33,13 @@ class DeltaInsertIntoColumnOrderSuite extends DeltaInsertIntoTest { spark.conf.set(SQLConf.ANSI_ENABLED.key, "true") } + /** Collects inserts that don't support implicit casting and will fail if the input data type + * doesn't match the expected column type. + * These are all dataframe inserts that use by name resolution, except for streaming writes. + */ + private val insertsWithoutImplicitCastSupport: Set[Insert] = + insertsByName.intersect(insertsDataframe) - StreamingInsert + test("all test cases are implemented") { checkAllTestCasesImplemented() } @@ -68,9 +75,8 @@ class DeltaInsertIntoColumnOrderSuite extends DeltaInsertIntoTest { overwriteWhere = "a" -> 1, insertData = TestData("a long, c int, b int", Seq("""{ "a": 1, "c": 4, "b": 5 }""")), expectedResult = ExpectedResult.Success(expectedAnswer), - // Exclude dataframe inserts by name (except streaming) which don't support implicit cast. - // See negative test below. - includeInserts = inserts -- (insertsByName.intersect(insertsDataframe) - StreamingInsert) + // Inserts that don't support implicit cast are failing, these are covered in the test below. + includeInserts = inserts -- insertsWithoutImplicitCastSupport ) } @@ -87,7 +93,7 @@ class DeltaInsertIntoColumnOrderSuite extends DeltaInsertIntoTest { "currentField" -> "a", "updateField" -> "a" ))}), - includeInserts = insertsByName.intersect(insertsDataframe) - StreamingInsert + includeInserts = insertsWithoutImplicitCastSupport ) // Inserting using a different ordering for struct fields is full of surprises... @@ -149,9 +155,8 @@ class DeltaInsertIntoColumnOrderSuite extends DeltaInsertIntoTest { insertData = TestData("a long, s struct ", Seq("""{ "a": 1, "s": { "y": 5, "x": 4 } }""")), expectedResult = ExpectedResult.Success(expectedAnswer), - // Exclude dataframe inserts by name (except streaming) which don't support implicit cast. - // See negative test below. - includeInserts = inserts -- (insertsByName.intersect(insertsDataframe) - StreamingInsert) + // Inserts that don't support implicit cast are failing, these are covered in the test below. + includeInserts = inserts -- insertsWithoutImplicitCastSupport ) } @@ -171,6 +176,6 @@ class DeltaInsertIntoColumnOrderSuite extends DeltaInsertIntoTest { "currentField" -> "a", "updateField" -> "a" ))}), - includeInserts = insertsDataframe.intersect(insertsByName) - StreamingInsert + includeInserts = insertsWithoutImplicitCastSupport ) }