Fix a data loss bug in MergeIntoCommand
Signed-off-by: Eunjin Song <[email protected]>
sezruby committed Sep 30, 2023
1 parent d42a22d commit 09c9d83
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
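
The two changed files thread planning-time state into execution instead of letting the command re-derive it when it runs: PreprocessTableMerge now reads DELTA_SCHEMA_AUTO_MIGRATE once (as canMergeSchema) and passes it to MergeIntoCommand as autoSchemaMergeEnabled, and the target output columns are computed once and handed through findTouchedFiles, writeInsertsOnlyWhenNoMatchedClauses, writeAllChanges, and buildTargetPlanWithFiles. A minimal Scala sketch of that pattern follows; the names are simplified stand-ins, not Delta's actual API, but it shows how a live config read at execution time can disagree with the value the plan was built under.

object SessionConf {
  // Stand-in for the session-level DELTA_SCHEMA_AUTO_MIGRATE flag.
  @volatile var autoSchemaMerge: Boolean = true
}

// Before: the command consulted the live config when it executed, which may have
// changed (or belong to a different session) since the merge was planned.
case class MergeCommandBefore(migratedSchema: Option[Seq[String]]) {
  def canMergeSchema: Boolean = SessionConf.autoSchemaMerge
}

// After: the planner captures the flag and passes it in, mirroring
// MergeIntoCommand(..., Some(finalSchema), canMergeSchema) in this commit.
case class MergeCommandAfter(
    migratedSchema: Option[Seq[String]],
    autoSchemaMergeEnabled: Boolean) {
  def canMergeSchema: Boolean = autoSchemaMergeEnabled
}

If the plan assumed schema evolution but the flag reads differently at execution (or the reverse), the write path operates on a schema it did not plan for, which is the kind of mismatch this commit closes.
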
@@ -61,8 +61,8 @@ case class PreprocessTableMerge(override val conf: SQLConf)
(matched ++ notMatched).filter(_.condition.nonEmpty).foreach { clause =>
checkCondition(clause.condition.get, clause.clauseType.toUpperCase(Locale.ROOT))
}

val shouldAutoMigrate = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) && migrateSchema
val canMergeSchema = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE)
val shouldAutoMigrate = canMergeSchema && migrateSchema
val finalSchema = if (shouldAutoMigrate) {
// The implicit conversions flag allows any type to be merged from source to target if Spark
// SQL considers the source type implicitly castable to the target. Normally, mergeSchemas
@@ -208,6 +208,6 @@ case class PreprocessTableMerge(override val conf: SQLConf)

MergeIntoCommand(
source, target, tahoeFileIndex, condition,
processedMatched, processedNotMatched, Some(finalSchema))
processedMatched, processedNotMatched, Some(finalSchema), canMergeSchema)
}
}
@@ -197,6 +197,7 @@ object MergeStats {
* @param matchedClauses All info related to matched clauses.
* @param notMatchedClauses All info related to not matched clause.
* @param migratedSchema The final schema of the target - may be changed by schema evolution.
* @param autoSchemaMergeEnabled Auto schema merge enabled config used in PreprocessTableMerge.
*/
case class MergeIntoCommand(
@transient source: LogicalPlan,
@@ -205,13 +206,14 @@ case class MergeIntoCommand(
condition: Expression,
matchedClauses: Seq[DeltaMergeIntoMatchedClause],
notMatchedClauses: Seq[DeltaMergeIntoInsertClause],
migratedSchema: Option[StructType]) extends RunnableCommand
migratedSchema: Option[StructType],
autoSchemaMergeEnabled: Boolean) extends RunnableCommand
with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation {

import SQLMetrics._
import MergeIntoCommand._

override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE)
override val canMergeSchema: Boolean = autoSchemaMergeEnabled
override val canOverwriteSchema: Boolean = false

@transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
@@ -265,13 +267,14 @@ case class MergeIntoCommand(
isOverwriteMode = false, rearrangeOnly = false)
}

val targetOutputCols = getTargetOutputCols(deltaTxn, spark)
val deltaActions = {
if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) {
writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn)
writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn, targetOutputCols)
} else {
val filesToRewrite = findTouchedFiles(spark, deltaTxn)
val filesToRewrite = findTouchedFiles(spark, deltaTxn, targetOutputCols)
val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") {
writeAllChanges(spark, deltaTxn, filesToRewrite)
writeAllChanges(spark, deltaTxn, filesToRewrite, targetOutputCols)
}
filesToRewrite.map(_.remove) ++ newWrittenFiles
}
@@ -309,9 +312,9 @@ case class MergeIntoCommand(
*/
private def findTouchedFiles(
spark: SparkSession,
deltaTxn: OptimisticTransaction
deltaTxn: OptimisticTransaction,
targetOutputCols: Seq[NamedExpression]
): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") {

// Accumulator to collect all the distinct touched files
val touchedFilesAccum = new SetAccumulator[String]()
spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME)
@@ -334,7 +337,9 @@
// - the target file name the row is from to later identify the files touched by matched rows
val joinToFindTouchedFiles = {
val sourceDF = Dataset.ofRows(spark, source)
val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles))
val targetDF = Dataset
.ofRows(spark,
buildTargetPlanWithFiles(deltaTxn, spark, dataSkippedFiles, targetOutputCols))
.withColumn(ROW_ID_COL, monotonically_increasing_id())
.withColumn(FILE_NAME_COL, input_file_name())
sourceDF.join(targetDF, new Column(condition), "inner")
@@ -396,14 +401,15 @@
*/
private def writeInsertsOnlyWhenNoMatchedClauses(
spark: SparkSession,
deltaTxn: OptimisticTransaction
deltaTxn: OptimisticTransaction,
targetOutputCols: Seq[NamedExpression]
): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {

// UDFs to update metrics
val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows")
val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted")

val outputColNames = getTargetOutputCols(deltaTxn).map(_.name)
val outputColNames = targetOutputCols.map(_.name)
// we use head here since we know there is only a single notMatchedClause
val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) :+ incrInsertedCountExpr
val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) =>
@@ -423,7 +429,7 @@

// target DataFrame
val targetDF = Dataset.ofRows(
spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles))
spark, buildTargetPlanWithFiles(deltaTxn, spark, dataSkippedFiles, targetOutputCols))

val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti")
.select(outputCols: _*)
@@ -456,13 +462,13 @@
private def writeAllChanges(
spark: SparkSession,
deltaTxn: OptimisticTransaction,
filesToRewrite: Seq[AddFile]
filesToRewrite: Seq[AddFile],
targetOutputCols: Seq[NamedExpression]
): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
val targetOutputCols = getTargetOutputCols(deltaTxn)

// Generate a new logical plan that has same output attributes exprIds as the target plan.
// This allows us to apply the existing resolved update/insert expressions.
val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite)
val newTarget = buildTargetPlanWithFiles(deltaTxn, spark, filesToRewrite, targetOutputCols)
val joinType = if (isMatchedOnly &&
spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) {
"rightOuter"
@@ -568,8 +574,9 @@
*/
private def buildTargetPlanWithFiles(
deltaTxn: OptimisticTransaction,
files: Seq[AddFile]): LogicalPlan = {
val targetOutputCols = getTargetOutputCols(deltaTxn)
spark: SparkSession,
files: Seq[AddFile],
targetOutputCols: Seq[NamedExpression]): LogicalPlan = {
val plan = {
// We have to do surgery to use the attributes from `targetOutputCols` to scan the table.
// In cases of schema evolution, they may not be the same type as the original attributes.
@@ -590,12 +597,12 @@
// create an alias
val aliases = plan.output.map {
case newAttrib: AttributeReference =>
val existingTargetAttrib = getTargetOutputCols(deltaTxn).find { col =>
conf.resolver(col.name, newAttrib.name)
val existingTargetAttrib = targetOutputCols.find { col =>
spark.sessionState.conf.resolver(col.name, newAttrib.name)
}.getOrElse {
throw new AnalysisException(
s"Could not find ${newAttrib.name} among the existing target output " +
s"${getTargetOutputCols(deltaTxn)}")
s"$targetOutputCols")
}.asInstanceOf[AttributeReference]

if (existingTargetAttrib.exprId == newAttrib.exprId) {
@@ -619,9 +626,11 @@

private def seqToString(exprs: Seq[Expression]): String = exprs.map(_.sql).mkString("\n\t")

private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = {
private def getTargetOutputCols(
txn: OptimisticTransaction,
spark: SparkSession): Seq[NamedExpression] = {
txn.metadata.schema.map { col =>
target.output.find(attr => conf.resolver(attr.name, col.name)).map { a =>
target.output.find(attr => spark.sessionState.conf.resolver(attr.name, col.name)).map { a =>
AttributeReference(col.name, col.dataType, col.nullable)(a.exprId)
}.getOrElse(
Alias(Literal(null), col.name)())
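
The getTargetOutputCols and buildTargetPlanWithFiles changes above follow the same theme: inside the command, the implicit conf appears to resolve to whatever SQLConf is active when the command runs, so column resolution now goes through an explicitly passed SparkSession. A small sketch of that resolution call, with an assumed helper name rather than Delta code (usable from a Scala REPL or spark-shell):

import org.apache.spark.sql.SparkSession

// Resolve a column name against the target's columns using the given session's
// resolver, which honours that session's spark.sql.caseSensitive setting.
def resolveTargetColumn(
    spark: SparkSession,
    targetCols: Seq[String],
    name: String): Option[String] =
  targetCols.find(col => spark.sessionState.conf.resolver(col, name))
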
@@ -712,7 +721,7 @@ object MergeIntoCommand {
val outputProj = UnsafeProjection.create(outputRowEncoder.schema)

def shouldDeleteRow(row: InternalRow): Boolean =
row.getBoolean(outputRowEncoder.schema.fields.size)
row.getBoolean(row.numFields - 2)

def processRow(inputRow: InternalRow): InternalRow = {
if (targetRowHasNoMatchPred.eval(inputRow)) {
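
The hunk above changes how the per-row delete flag is located. The processed join output appears to carry the (possibly evolved) target columns followed by two trailing fields, a row-dropped flag and a metric-increment column, so reading the flag at index outputRowEncoder.schema.fields.size can land on a data column whenever the row's width no longer matches that encoder's schema; counting back from the end of the row does not depend on the target schema at all. A small illustration with an assumed row layout, not Delta's actual projection (runnable in spark-shell, which has catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

// Assumed layout after the merge projection: target columns (here three of them,
// one added by schema evolution), then the row-dropped flag, then a metric counter.
//   [id, value, newCol, shouldDelete, metricIncrement]
val row = new GenericInternalRow(Array[Any](1, 10L, 7, true, 1))

// Indexing by a stale schema size (say 2 fields) would point at newCol, not the flag.
// Counting back from the end works however many target columns the row carries:
val shouldDelete = row.getBoolean(row.numFields - 2)   // true
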
