initial implementation of merge_asof (#2)
jaebahk authored Jul 2, 2019
1 parent 4ebff5b commit 3102507
Showing 6 changed files with 354 additions and 1 deletion.
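
This commit threads a pandas-style as-of merge through the SQL stack: a MergeAsOf logical node, a ResolveMergeAsof analyzer rule, a Dataset.mergeAsOf API, a MergeAsOfJoinExec physical operator, and a test suite. A minimal usage sketch, reusing the toy data from the "basic merge" test below (assumes a SparkSession `spark` and its implicits):

import spark.implicits._

val df1 = Seq((2001, 1, 1.0), (2001, 2, 1.1), (2002, 1, 1.2)).toDF("time", "id", "v")
val df2 = Seq((2001, 1, 4), (2001, 2, 5)).toDF("time", "id", "v2")

// For each df1 row, take the latest df2 row with the same "id" whose "time" is <= df1's "time".
val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
// Expected rows, per the suite: (2001, 1, 1.0, 4), (2002, 1, 1.2, 4), (2001, 2, 1.1, 5)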
@@ -184,6 +184,7 @@ class Analyzer(
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
ResolveMergeAsof ::
ResolveOutputRelation ::
ExtractWindowExpressions ::
GlobalAggregates ::
@@ -2286,6 +2287,21 @@
}
}

object ResolveMergeAsof extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case m @ MergeAsOf(left, right, leftOn, rightOn, leftBy, rightBy, tolerance)
        if left.resolved && right.resolved && m.duplicateResolved =>
      val lUniqueOutput = left.output.filterNot(att => leftBy == att || leftOn == att)
      val rUniqueOutput = right.output.filterNot(att => rightBy == att || rightOn == att)

      val output = Seq(leftOn.asInstanceOf[Attribute]) ++
        leftBy.map { expr => expr.asInstanceOf[Attribute] } ++ lUniqueOutput ++ rUniqueOutput
      Project(output, plan)
    case _ => plan
  }
}

/**
* Resolves columns of an output table from the data in a logical plan. This rule will:
*
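
What the rule above buys, sketched with the "basic merge" data from the test suite: the Project keeps a single copy of the on/by key columns (taken from the left side), followed by the remaining left and right columns, so the merged frame does not repeat the keys.

val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
res.columns
// Array(time, id, v, v2) -- "time" and "id" appear once, from the left side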
@@ -381,6 +381,29 @@ case class Join(
}
}

object MergeAsOf {
  def apply(left: LogicalPlan, right: LogicalPlan, leftOn: Expression, rightOn: Expression,
      leftBy: Expression, rightBy: Expression, tolerance: String): MergeAsOf = {
    new MergeAsOf(left, right, leftOn, rightOn, leftBy, rightBy, tolerance)
  }
}

case class MergeAsOf(
    left: LogicalPlan,
    right: LogicalPlan,
    leftOn: Expression,
    rightOn: Expression,
    leftBy: Expression,
    rightBy: Expression,
    tolerance: String)
  extends BinaryNode {

  // TODO polymorphic keys
  // Right-side columns are marked nullable: left rows without an as-of match are null-padded.
  override def output: Seq[Attribute] = left.output ++ right.output.map(_.withNullability(true))

  def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
}

/**
* Base trait for DataSourceV2 write commands
*/
13 changes: 13 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1193,6 +1193,13 @@ class Dataset[T] private[sql](
joinWith(other, condition, "inner")
}

/**
 * Performs a pandas-style as-of merge with another Dataset: rows are matched exactly on the
 * `leftBy`/`rightBy` columns, and each row of this Dataset is paired with the most recent
 * `right` row whose `rightOn` value is less than or equal to its `leftOn` value. Rows with
 * no match are padded with nulls for the right-side columns.
 *
 * @param tolerance look-back limit, parsed as a `scala.concurrent.duration.Duration`
 *                  (e.g. "2ms"); not yet enforced by the physical operator.
 */
def mergeAsOf[U](
    right: Dataset[U],
    leftOn: Column,
    rightOn: Column,
    leftBy: Column,
    rightBy: Column,
    tolerance: String = "0ms"): DataFrame = {
  withPlan {
    MergeAsOf(logicalPlan, right.logicalPlan, leftOn.expr, rightOn.expr,
      leftBy.expr, rightBy.expr, tolerance)
  }
}

/**
* Returns a new Dataset with each partition sorted by the given expressions.
*
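
A sketch of the new Dataset API with the optional tolerance argument, mirroring the "default merge_asof" test below; note that in this initial version the tolerance string is parsed by MergeAsOfJoinExec but not yet enforced:

val quotes = Seq(
  (23, "MSFT", 51.95, 51.96),
  (30, "MSFT", 51.97, 51.98)
).toDF("time", "ticker", "bid", "ask")

val trades = Seq((38, "MSFT", 51.95, 155)).toDF("time", "ticker", "price", "quantity")

// Match each trade with the latest quote for the same ticker at or before the trade time.
val asOf = trades.mergeAsOf(quotes, trades("time"), quotes("time"),
  trades("ticker"), quotes("ticker"), "2ms")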
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.adaptive.LogicalQueryStage
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, MergeAsOfJoinExec}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
@@ -740,6 +740,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
case logical.MergeAsOf(left, right, leftOn, rightOn, leftBy, rightBy, tolerance) =>
  joins.MergeAsOfJoinExec(planLater(left), planLater(right), leftOn, rightOn,
    leftBy, rightBy, tolerance) :: Nil
case _ => Nil
}
}
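
A quick way to confirm the new strategy fires, in the spirit of the existing join suites (a sketch, not code from this commit):

import org.apache.spark.sql.execution.joins.MergeAsOfJoinExec

val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))
val planned = res.queryExecution.executedPlan
assert(planned.collect { case m: MergeAsOfJoinExec => m }.nonEmpty)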
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import scala.concurrent.duration._
import scala.util.control.Breaks._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet


case class MergeAsOfJoinExec(
    left: SparkPlan,
    right: SparkPlan,
    leftOn: Expression,
    rightOn: Expression,
    leftBy: Expression,
    rightBy: Expression,
    tolerance: String) extends BinaryExecNode {

  override def output: Seq[Attribute] = left.output ++ right.output.map(_.withNullability(true))

  override def outputPartitioning: Partitioning = left.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] =
    HashClusteredDistribution(Seq(leftBy)) :: HashClusteredDistribution(Seq(rightBy)) :: Nil

  override def outputOrdering: Seq[SortOrder] =
    getKeyOrdering(Seq(leftBy, leftOn), left.outputOrdering)

  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
    Seq(leftBy, leftOn).map(SortOrder(_, Ascending)) ::
      Seq(rightBy, rightOn).map(SortOrder(_, Ascending)) :: Nil
  }

  private val emptyVal: Array[Any] = Array.fill(right.output.length)(null)
  private def rDummy = InternalRow(emptyVal: _*)
  private def joinedRow = new JoinedRow()

  private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
    : Seq[SortOrder] = {
    val requiredOrdering = keys.map(SortOrder(_, Ascending))
    if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
      keys.zip(childOutputOrdering).map { case (key, childOrder) =>
        SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
      }
    } else {
      requiredOrdering
    }
  }

  private def match_tolerance(
      currLeft: (InternalRow, Iterator[InternalRow]),
      currRight: (InternalRow, Iterator[InternalRow]),
      tolerance: Duration,
      resultProj: InternalRow => InternalRow
    ): Iterator[InternalRow] = {
    // Walk the right-side rows once, keeping in rPrev the last row whose key (index 0)
    // is <= the current left row's key. Both sides arrive sorted ascending on the key.
    var rHead = if (currRight._2.hasNext) {
      currRight._2.next()
    } else {
      InternalRow.empty
    }
    var rPrev = rHead.copy()

    currLeft._2.map { lHead =>
      breakable {
        while (rHead.getInt(0) <= lHead.getInt(0)) {
          // TODO make index agnostic and check by type (timestamp)
          rPrev = rHead.copy()
          if (currRight._2.hasNext) {
            rHead = currRight._2.next()
          } else {
            break
          }
        }
      }
      if (rPrev == InternalRow.empty || rPrev.getInt(0) > lHead.getInt(0)) {
        // No right row at or before this left row: pad with nulls.
        resultProj(joinedRow(lHead, rDummy))
      } else {
        resultProj(joinedRow(lHead, rPrev))
      }
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val duration = Duration(tolerance)
    val inputSchema = left.output ++ right.output

    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, inputSchema)
      if (!leftIter.hasNext || !rightIter.hasNext) {
        // No right-side rows in this partition: null-pad every left row.
        leftIter.map(r => resultProj(joinedRow(r, rDummy)))
      } else {
        val rightGroupedIterator =
          GroupedIterator(rightIter, Seq(rightBy), right.output)

        if (rightGroupedIterator.hasNext) {
          val currRight = rightGroupedIterator.next()
          val leftGroupedIterator =
            GroupedIterator(leftIter, Seq(leftBy), left.output)
          if (leftGroupedIterator.hasNext) {
            val currLeft = leftGroupedIterator.next()
            // Note: only the first "by" group of each partition is merged here.
            match_tolerance(currLeft, currRight, duration, resultProj)
          } else {
            Iterator.empty
          }
        } else {
          Iterator.empty
        }
      }
    }
  }
}
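
The matching loop in match_tolerance is easier to see on plain collections. A reference sketch of the intended backward-looking match, ignoring the by-key grouping and the (currently unused) tolerance:

// For each left timestamp, keep the last right row whose timestamp is <= it, if any.
def asOfMatch[R](leftTimes: Seq[Int], right: Seq[(Int, R)]): Seq[Option[R]] = {
  val sortedRight = right.sortBy(_._1)
  leftTimes.map(t => sortedRight.takeWhile(_._1 <= t).lastOption.map(_._2))
}

asOfMatch(Seq(23, 38, 48), Seq((23, "q23"), (30, "q30"), (41, "q41")))
// Seq(Some(q23), Some(q30), Some(q41)); e.g. the row at time 38 picks the right row at time 30,
// just as the MSFT trade at 38 picks the quote at 30 in the suite below.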
157 changes: 157 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/MergeAsOfSuite.scala
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class MergeAsOfSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  setupTestData()

  def statisticSizeInByte(df: DataFrame): BigInt = {
    df.queryExecution.optimizedPlan.stats.sizeInBytes
  }

  test("basic merge") {
    val df1 = Seq(
      (2001, 1, 1.0),
      (2001, 2, 1.1),
      (2002, 1, 1.2)
    ).toDF("time", "id", "v")

    val df2 = Seq(
      (2001, 1, 4),
      (2001, 2, 5)
    ).toDF("time", "id", "v2")

    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))

    val expected = Seq(
      (2001, 1, 1.0, 4),
      (2002, 1, 1.2, 4),
      (2001, 2, 1.1, 5)
    ).toDF("time", "id", "v", "v2")

    assert(res.collect() === expected.collect())

    val res2 = df1.select("time", "id").mergeAsOf(
      df2.withColumn("v3", df2("v2") * 3 cast "Int"),
      df1("time"), df2("time"), df1("id"), df2("id"))

    val expected2 = Seq(
      (2001, 1, 4, 12),
      (2002, 1, 4, 12),
      (2001, 2, 5, 15)
    ).toDF("time", "id", "v2", "v3")

    assert(res2.collect() === expected2.collect())
  }

  test("default merge_asof") {
    val quotes = Seq(
      (23, "GOOG", 720.50, 720.93),
      (23, "MSFT", 51.95, 51.96),
      (30, "MSFT", 51.97, 51.98),
      (41, "MSFT", 51.99, 52.00),
      (48, "GOOG", 720.50, 720.93),
      (49, "AAPL", 97.99, 98.01),
      (72, "GOOG", 720.50, 720.88),
      (75, "MSFT", 52.01, 52.03)
    ).toDF("time", "ticker", "bid", "ask")

    val trades = Seq(
      (23, "MSFT", 51.95, 75),
      (38, "MSFT", 51.95, 155),
      (48, "GOOG", 720.77, 100),
      (48, "GOOG", 720.92, 100),
      (48, "AAPL", 98.00, 100)
    ).toDF("time", "ticker", "price", "quantity")

    val res = trades.mergeAsOf(quotes, trades("time"), quotes("time"),
      trades("ticker"), quotes("ticker"), "2ms")

    val expected = Seq(
      (23, "MSFT", 51.95, 75, 51.95, 51.96),
      (38, "MSFT", 51.95, 155, 51.97, 51.98),
      (48, "GOOG", 720.77, 100, 720.5, 720.93),
      (48, "GOOG", 720.92, 100, 720.5, 720.93),
      (48, "AAPL", 98.0, 100, 0.0, 0.0)
    ).toDF("time", "ticker", "price", "quantity", "bid", "ask")

    res.show()
    expected.show()
    // println(res.collect() === expected.collect()) // TODO sort results by key
  }

  test("partial key mismatch") {
    val df1 = Seq(
      (2001, 1, 1.0),
      (2001, 2, 1.1),
      (2002, 1, 1.2)
    ).toDF("time", "id", "v")

    val df2 = Seq(
      (2001, 1, 5),
      (2001, 4, 4)
    ).toDF("time", "id", "v2")

    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))

    val expected = Seq(
      (2001, 1, 1.0, 5),
      (2002, 1, 1.2, 5),
      (2001, 2, 1.1, 0)
    ).toDF("time", "id", "v", "v2")

    assert(res.collect() === expected.collect())
  }

  test("complete key mismatch") {
    val df1 = Seq(
      (2001, 1, 1.0),
      (2001, 2, 1.1),
      (2002, 1, 1.2)
    ).toDF("time", "id", "v")

    val df2 = Seq(
      (2001, 3, 5),
      (2001, 4, 4)
    ).toDF("time", "id", "v2")

    val res = df1.mergeAsOf(df2, df1("time"), df2("time"), df1("id"), df2("id"))

    val expected = Seq(
      (2001, 1, 1.0, 0),
      (2002, 1, 1.2, 0),
      (2001, 2, 1.1, 0)
    ).toDF("time", "id", "v", "v2")

    assert(res.collect() === expected.collect())
  }
}
