Commit: Add extra schema check

RustedBones committed Feb 15, 2024
1 parent 1d49f8e commit 1dd677b
Showing 1 changed file with 30 additions and 6 deletions.
@@ -34,14 +34,18 @@ import com.spotify.scio.io.ClosedTap
 import com.spotify.scio.parquet.avro._
 import com.spotify.scio.values.SCollection
 import com.twitter.algebird._
+import org.apache.avro.{Schema, SchemaCompatibility}
 import org.apache.avro.generic.GenericRecord
 import org.apache.avro.specific.SpecificRecordBase
 import org.apache.beam.sdk.io.TextIO
+import org.apache.beam.sdk.options.PipelineOptions
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding
 import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.File
 import java.nio.ByteBuffer
 import scala.annotation.tailrec
 import scala.collection.compat.BuildFrom.fromCanBuildFromConversion
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.language.higherKinds
@@ -606,6 +610,9 @@ object BigDiffy extends Command with Serializable {
     sys.exit(1)
   }
 
+  private def avroFileSchema(path: String, options: PipelineOptions): Schema =
+    new AvroSampler(path, conf = Some(options)).sample(1, head = true).head.getSchema
+
   private[diffy] def avroKeyFn(keys: Seq[String]): GenericRecord => MultiKey = {
     @tailrec
     def get(xs: Array[String], i: Int, r: GenericRecord): String =
@@ -743,13 +750,30 @@ object BigDiffy extends Command with Serializable {
     val result = inputMode match {
       case "avro" =>
         if (rowRestriction.isDefined) {
-          throw new IllegalArgumentException(s"rowRestriction cannot be passed for avro inputs")
+          throw new IllegalArgumentException("rowRestriction cannot be passed for avro inputs")
         }
+
+        val lhsSchema = avroFileSchema(lhs, sc.options)
+        val rhsSchema = avroFileSchema(rhs, sc.options)
+
+        val lhsReader = SchemaCompatibility.checkReaderWriterCompatibility(lhsSchema, rhsSchema)
+        val rhsReader = SchemaCompatibility.checkReaderWriterCompatibility(rhsSchema, lhsSchema)
+
+        import SchemaCompatibility.SchemaCompatibilityType._
+        val schema = (lhsReader.getType, rhsReader.getType) match {
+          case (COMPATIBLE, COMPATIBLE) =>
+            if (lhsSchema != rhsSchema) {
+              logger.warn("Avro schemas are compatible, but not equal. Using schema from {}", lhs)
+            }
+            lhsSchema
+          case (COMPATIBLE, INCOMPATIBLE) =>
+            lhsSchema
+          case (INCOMPATIBLE, COMPATIBLE) =>
+            rhsSchema
+          case _ =>
+            throw new IllegalArgumentException("Avro schemas are incompatible")
+        }
 
-        val schema = new AvroSampler(rhs, conf = Some(sc.options))
-          .sample(1, head = true)
-          .head
-          .getSchema
         implicit val grCoder: Coder[GenericRecord] = avroGenericRecordCoder(schema)
         val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)
         val lhsSCollection = sc.avroFile(lhs, schema)
@@ -758,7 +782,7 @@ object BigDiffy extends Command with Serializable {
           .diff[GenericRecord](lhsSCollection, rhsSCollection, diffy, avroKeyFn(keys), ignoreNan)
       case "parquet" =>
         if (rowRestriction.isDefined) {
-          throw new IllegalArgumentException(s"rowRestriction cannot be passed for Parquet inputs")
+          throw new IllegalArgumentException("rowRestriction cannot be passed for Parquet inputs")
         }
         val compatSchema = ParquetIO.getCompatibleSchemaForFiles(lhs, rhs)
         val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)(
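
Note on the check itself: Avro's SchemaCompatibility.checkReaderWriterCompatibility is directional, meaning a reader schema may be able to decode a writer's data while the reverse fails. That is why the new code evaluates both orderings before picking a schema. Below is a minimal, self-contained sketch of that asymmetry; the User schemas are hypothetical, not taken from the ratatool codebase:

import org.apache.avro.{Schema, SchemaCompatibility}

object SchemaCheckDemo {
  // Build a small record schema from a JSON field list.
  private def user(fields: String): Schema =
    new Schema.Parser().parse(s"""{"type":"record","name":"User","fields":[$fields]}""")

  val v1: Schema = user("""{"name":"id","type":"string"}""")
  // v2 adds a field *without* a default value.
  val v2: Schema = user("""{"name":"id","type":"string"},{"name":"age","type":"int"}""")

  def main(args: Array[String]): Unit = {
    // v1 as reader can decode v2 data (the unknown "age" field is skipped): COMPATIBLE
    println(SchemaCompatibility.checkReaderWriterCompatibility(v1, v2).getType)
    // v2 as reader cannot decode v1 data ("age" is absent and has no default): INCOMPATIBLE
    println(SchemaCompatibility.checkReaderWriterCompatibility(v2, v1).getType)
  }
}

Mapped onto the diff above, this pair would hit the (COMPATIBLE, INCOMPATIBLE) arm: the schema kept is the one that can decode both inputs (v1 here), at the cost of ignoring fields only the other side carries.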
