
Commit 9edb9b3

null values properly returned (#3)
* null values properly returned

* create joined row object
jaebahk authored Jul 9, 2019
1 parent 3102507 commit 9edb9b3
Showing 2 changed files with 42 additions and 60 deletions.
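
The gist of the fix: when a left row has no as-of match on the right, the exec now joins it against a dummy right row filled with nulls (rDummy) rather than emitting default values. Below is a minimal sketch of that pattern using Spark's catalyst internals; the two-column right side and the sample values are invented for illustration, not taken from this repository.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow

object NullPaddingSketch {
  def main(args: Array[String]): Unit = {
    // Pretend the right side contributes two columns (say, bid and ask).
    val rightWidth = 2

    // A dummy right row holding only nulls, mirroring `rDummy` in the diff below.
    val emptyVal: Array[Any] = Array.fill(rightWidth)(null)
    def rDummy: InternalRow = InternalRow(emptyVal: _*)

    // A left row (time, price) that found no match on the right.
    val leftRow = InternalRow(48L, 98.0)

    // JoinedRow concatenates the two rows without copying; the right half
    // of the combined row reads back as null.
    val joined = new JoinedRow(leftRow, rDummy)
    assert(joined.numFields == 4)
    assert(joined.isNullAt(2) && joined.isNullAt(3))
  }
}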
@@ -55,9 +55,9 @@ case class MergeAsOfJoinExec(
     Seq(rightBy, rightOn).map(SortOrder(_, Ascending)) :: Nil
   }

-  private val joinedRow = new JoinedRow()
   private val emptyVal: Array[Any] = Array.fill(right.output.length)(null)
   private def rDummy = InternalRow(emptyVal: _*)
+  private def joinedRow = new JoinedRow()

   private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
     : Seq[SortOrder] = {
@@ -77,7 +77,6 @@ case class MergeAsOfJoinExec(
       tolerance: Duration,
       resultProj: InternalRow => InternalRow
   ): Iterator[InternalRow] = {
-    // val joinedRow = new JoinedRow()
     var rHead = if (currRight._2.hasNext) {
       currRight._2.next()
     } else {
@@ -110,15 +109,12 @@ case class MergeAsOfJoinExec(
   protected override def doExecute(): RDD[InternalRow] = {

     val duration = Duration(tolerance)
-    val inputSchema = left.output ++ right.output

     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      // val joinedRow = new JoinedRow()
-      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, inputSchema)
+      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
       if (!leftIter.hasNext || !rightIter.hasNext) {
         leftIter.map(r => resultProj(joinedRow(r, rDummy)))
       } else {
-        val joinedRow = new JoinedRow()
        val rightGroupedIterator =
          GroupedIterator(rightIter, Seq(rightBy), right.output)

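Two details worth noting in the hunk above. First, `joinedRow` became a def (and the commented-out and per-partition `val joinedRow` copies were dropped), so each call site gets a fresh mutable JoinedRow instead of sharing one instance across closures; this matches the "create joined row object" bullet in the commit message. Second, the projection is now bound as UnsafeProjection.create(output, output), which suggests this operator's output is exactly left.output ++ right.output, making the removed inputSchema binding redundant. A hedged sketch of that projection behavior, with an invented three-attribute schema:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.types.{DoubleType, LongType}

object ProjectionSketch {
  def main(args: Array[String]): Unit = {
    // Invented schema: two left columns plus one right column.
    val output = Seq(
      AttributeReference("time", LongType)(),
      AttributeReference("price", DoubleType)(),
      AttributeReference("bid", DoubleType)())

    // Binding the projection to the same attributes it emits is valid when
    // the operator's output is left.output ++ right.output.
    val resultProj = UnsafeProjection.create(output, output)

    // An unmatched left row joined to a one-column null dummy.
    val rightNulls: Array[Any] = Array(null)
    val unsafe = resultProj(new JoinedRow(InternalRow(48L, 98.0), InternalRow(rightNulls: _*)))
    assert(unsafe.isNullAt(2)) // the null padding survives the projection
  }
}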
sql/core/src/test/scala/org/apache/spark/sql/MergeAsOfSuite.scala: 94 changes (40 additions, 54 deletions)
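The rewritten tests below lean on QueryTest.checkAnswer with explicit Row literals instead of collecting two DataFrames and comparing arrays. For un-ordered plans, checkAnswer compares result sets without relying on output order and treats nulls as ordinary values, which is what lets the suite drop its "// TODO sort results by key" workaround. A toy version of that order-insensitive comparison, not the real implementation:

object CheckAnswerSketch {
  // Toy stand-in for QueryTest.checkAnswer: compare row multisets by a
  // canonical string form so that row order does not matter and nulls compare fine.
  def sameRows(actual: Seq[Seq[Any]], expected: Seq[Seq[Any]]): Boolean =
    actual.map(_.mkString(",")).sorted == expected.map(_.mkString(",")).sorted

  def main(args: Array[String]): Unit = {
    assert(sameRows(
      Seq(Seq(2001, 2, 1.1, null), Seq(2001, 1, 1.0, 5)),
      Seq(Seq(2001, 1, 1.0, 5), Seq(2001, 2, 1.1, null))))
  }
}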
@@ -35,12 +35,6 @@ import org.apache.spark.sql.types.StructType
 class MergeAsOfSuite extends QueryTest with SharedSQLContext{
   import testImplicits._

-  setupTestData()
-
-  def statisticSizeInByte(df: DataFrame): BigInt = {
-    df.queryExecution.optimizedPlan.stats.sizeInBytes
-  }
-
   test("basic merge") {
     val df1 = Seq(
       (2001, 1, 1.0),
@@ -53,25 +47,23 @@ class MergeAsOfSuite extends QueryTest with SharedSQLContext{
       (2001, 2, 5),
     ).toDF("time", "id", "v2")

-    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
-
-    val expected = Seq(
-      (2001, 1, 1.0, 4),
-      (2002, 1, 1.2, 4),
-      (2001, 2, 1.1, 5)
-    ).toDF("time", "id", "v", "v2")
-
-    assert(res.collect() === expected.collect())
-
-    val res2 = df1.select("time", "id").mergeAsOf(df2.withColumn("v3", df2("v2") * 3 cast "Int"), df1("time"), df2("time"), df1("id"), df2("id"))
-
-    val expected2 = Seq(
-      (2001, 1, 4, 12),
-      (2002, 1, 4, 12),
-      (2001, 2, 5, 15)
-    ).toDF("time", "id", "v2", "v3")
-
-    assert(res2.collect() === expected2.collect())
+    checkAnswer(
+      df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id")),
+      Seq(
+        Row(2001, 1, 1.0, 4),
+        Row(2002, 1, 1.2, 4),
+        Row(2001, 2, 1.1, 5)
+      ))
+
+    checkAnswer(
+      df1.select("time", "id").mergeAsOf(
+        df2.withColumn("v3", df2("v2") * 3 cast "Int"), df1("time"), df2("time"), df1("id"), df2("id")
+      ),
+      Seq(
+        Row(2001, 1, 4, 12),
+        Row(2002, 1, 4, 12),
+        Row(2001, 2, 5, 15)
+      ))
   }

   test("default merge_asof") {
@@ -96,17 +88,15 @@ class MergeAsOfSuite extends QueryTest with SharedSQLContext{

-    val res = trades.mergeAsOf(quotes, trades("time"), quotes("time"), trades("ticker"), quotes("ticker"), "2ms")

-    val expected = Seq(
-      (23, "MSFT", 51.95, 75, 51.95, 51.96),
-      (38, "MSFT", 51.95, 155, 51.97, 51.98),
-      (48, "GOOG", 720.77, 100, 720.5, 720.93),
-      (48, "GOOG", 720.92, 100, 720.5, 720.93),
-      (48, "AAPL", 98.0, 100, 0.0, 0.0)
-    ).toDF("time", "ticker", "price", "quantity", "bid", "ask")

-    res.show()
-    expected.show()
-    // println(res.collect() === expected.collect()) // TODO sort results by key
+    checkAnswer(
+      trades.mergeAsOf(quotes, trades("time"), quotes("time"), trades("ticker"), quotes("ticker"), "2ms"),
+      Seq(
+        Row(23, "MSFT", 51.95, 75, 51.95, 51.96),
+        Row(38, "MSFT", 51.95, 155, 51.97, 51.98),
+        Row(48, "GOOG", 720.77, 100, 720.5, 720.93),
+        Row(48, "GOOG", 720.92, 100, 720.5, 720.93),
+        Row(48, "AAPL", 98.0, 100, null, null)
+      ))
   }

   test("partial key mismatch") {
@@ -121,15 +111,13 @@ class MergeAsOfSuite extends QueryTest with SharedSQLContext{
       (2001, 4, 4),
     ).toDF("time", "id", "v2")

-    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
-
-    val expected = Seq(
-      (2001, 1, 1.0, 5),
-      (2002, 1, 1.2, 5),
-      (2001, 2, 1.1, 0)
-    ).toDF("time", "id", "v", "v2")
-
-    assert(res.collect() === expected.collect())
+    checkAnswer(
+      df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id")),
+      Seq(
+        Row(2001, 1, 1.0, 5),
+        Row(2002, 1, 1.2, 5),
+        Row(2001, 2, 1.1, null)
+      ))
   }

   test("complete key mismatch") {
@@ -144,14 +132,12 @@ class MergeAsOfSuite extends QueryTest with SharedSQLContext{
       (2001, 4, 4),
     ).toDF("time", "id", "v2")

-    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
-
-    val expected = Seq(
-      (2001, 1, 1.0, 0),
-      (2002, 1, 1.2, 0),
-      (2001, 2, 1.1, 0)
-    ).toDF("time", "id", "v", "v2")
-
-    assert(res.collect() === expected.collect())
+    checkAnswer(
+      df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id")),
+      Seq(
+        Row(2001, 1, 1.0, null),
+        Row(2002, 1, 1.2, null),
+        Row(2001, 2, 1.1, null)
+      ))
   }
 }
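
For reference, a usage sketch of the mergeAsOf API these tests exercise. The DataFrame method is this fork's extension (signature as used in the suite above), not stock Spark, and the sample data is invented; the point is that after this commit an unmatched row comes back with null right-side columns instead of zeros.

import org.apache.spark.sql.SparkSession

object MergeAsOfUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("mergeAsOf-sketch")
      .getOrCreate()
    import spark.implicits._

    val trades = Seq(
      (23L, "MSFT", 51.95, 75),
      (48L, "AAPL", 98.0, 100)
    ).toDF("time", "ticker", "price", "quantity")

    val quotes = Seq(
      (23L, "MSFT", 51.95, 51.96),
      (48L, "GOOG", 720.5, 720.93)
    ).toDF("time", "ticker", "bid", "ask")

    // As-of join keyed by ticker with a 2ms tolerance: AAPL has no quote,
    // so its bid/ask come back null rather than 0.0 after this commit.
    // mergeAsOf is the fork's DataFrame extension, not a stock Spark API.
    trades.mergeAsOf(
      quotes, trades("time"), quotes("time"),
      trades("ticker"), quotes("ticker"), "2ms"
    ).show()

    spark.stop()
  }
}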
