Don't forget to fan out (#4740)
* Don't forget to fan out

* Add tests, simplify SCollectionWithHotKeyFanout, fix applyTransform transform naming

* Fix regex

* Reduce diff

* Revert MiMa-breaking change

---------

Co-authored-by: Michel Davit <[email protected]>
kellen and RustedBones authored Mar 16, 2023
1 parent 52c611e commit 18fe3ad
Showing 7 changed files with 157 additions and 53 deletions.
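For orientation (not part of the commit itself): the fan-out being "forgotten" is the hint set with withFanout / withHotKeyFanout on an SCollection, which should make aggregations pre-combine on intermediate keys before the final combine. A minimal, hypothetical Scio job sketching that API, assembled from the calls exercised by the tests added below:

    import com.spotify.scio._
    import com.twitter.algebird.Aggregator

    // Hypothetical example job, only to illustrate the fan-out hints involved here.
    object FanOutExample {
      def main(cmdlineArgs: Array[String]): Unit = {
        val (sc, _) = ContextAndArgs(cmdlineArgs)

        // Global combine with a fan-out hint: partial sums on 10 intermediate keys.
        sc.parallelize(1 to 100)
          .withFanout(10)
          .sum

        // Per-key combine with hot-key fan-out: hot keys are spread over 10 sub-keys first.
        sc.parallelize(1 to 100)
          .map(("a", _))
          .withHotKeyFanout(10)
          .aggregateByKey(Aggregator.max[Int])

        sc.run()
      }
    }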
@@ -67,7 +67,7 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
)(f: KV[K, UI] => (K, UO)): SCollection[(K, UO)] = {
self.transform(
_.withName("TupleToKv").toKV
.applyTransform(t)
.applyTransform(t.getName, t)
.withName("KvToTuple")
.map(f)
)
17 changes: 7 additions & 10 deletions scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala
@@ -202,16 +202,13 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
name: Option[String],
transform: PTransform[_ >: PCollection[T], PCollection[U]]
): SCollection[U] = {
val t =
if (
(classOf[Combine.Globally[T, U]] isAssignableFrom transform.getClass)
&& ScioUtil.isWindowed(this)
) {
// In case PCollection is windowed
transform.asInstanceOf[Combine.Globally[T, U]].withoutDefaults()
} else {
transform
}
val isCombineGlobally = classOf[Combine.Globally[T, U]].isAssignableFrom(transform.getClass)
val t = if (isCombineGlobally && ScioUtil.isWindowed(this)) {
// In case PCollection is windowed
transform.asInstanceOf[Combine.Globally[T, U]].withoutDefaults()
} else {
transform
}
context.wrap(this.applyInternal(name, t))
}

@@ -50,13 +50,21 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In
/** [[SCollection.aggregate[A,U]* SCollection.aggregate]] with fan out. */
def aggregate[A: Coder, U: Coder](aggregator: Aggregator[T, A, U]): SCollection[U] = {
val a = aggregator // defeat closure
coll.transform(_.map(a.prepare).sum(a.semigroup).map(a.present))
coll.transform { in =>
new SCollectionWithFanout(in.map(a.prepare), fanout)
.sum(a.semigroup)
.map(a.present)
}
}

/** [[SCollection.aggregate[A,U]* SCollection.aggregate]] with fan out. */
def aggregate[A: Coder, U: Coder](aggregator: MonoidAggregator[T, A, U]): SCollection[U] = {
val a = aggregator // defeat closure
coll.transform(_.map(a.prepare).fold(a.monoid).map(a.present))
coll.transform { in =>
new SCollectionWithFanout(in.map(a.prepare), fanout)
.fold(a.monoid)
.map(a.present)
}
}

/** [[SCollection.combine]] with fan out. */
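To spell out the fix above (annotation added here, not in the commit): inside transform the function receives a plain SCollection, so the old one-liner dispatched to the ordinary sum / fold and the wrapper's fanout value was silently dropped; re-wrapping in SCollectionWithFanout keeps the fan-out-aware overloads in play.

    // Before: `_` is a plain SCollection, so this `sum` is the non-fanout combine
    // and the `fanout` field of SCollectionWithFanout is never used.
    coll.transform(_.map(a.prepare).sum(a.semigroup).map(a.present))

    // After: re-wrap the mapped collection so the fan-out-aware `sum` runs.
    coll.transform { in =>
      new SCollectionWithFanout(in.map(a.prepare), fanout)
        .sum(a.semigroup)
        .map(a.present)
    }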
@@ -61,21 +61,19 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
*/
def aggregateByKey[U: Coder](
zeroValue: U
)(seqOp: (U, V) => U, combOp: (U, U) => U): SCollection[(K, U)] =
self.applyPerKey(
withFanout(Combine.perKey(Functions.aggregateFn(context, zeroValue)(seqOp, combOp)))
)(
kvToTuple
)
)(seqOp: (U, V) => U, combOp: (U, U) => U): SCollection[(K, U)] = {
val cmb = Combine.perKey[K, V, U](Functions.aggregateFn(context, zeroValue)(seqOp, combOp))
self.applyPerKey(withFanout(cmb))(kvToTuple)
}

/**
* [[PairSCollectionFunctions.aggregateByKey[A,U]* PairSCollectionFunctions.aggregateByKey]] with
* hot key fanout.
*/
def aggregateByKey[A: Coder, U: Coder](aggregator: Aggregator[V, A, U]): SCollection[(K, U)] =
self.self.context.wrap(self.self.internal).transform { in =>
self.self.transform { in =>
val a = aggregator // defeat closure
in.mapValues(a.prepare)
new SCollectionWithHotKeyFanout(context, in.mapValues(a.prepare), hotKeyFanout)
.sumByKey(a.semigroup)
.mapValues(a.present)
}
@@ -86,13 +84,14 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
*/
def aggregateByKey[A: Coder, U: Coder](
aggregator: MonoidAggregator[V, A, U]
): SCollection[(K, U)] =
self.self.context.wrap(self.self.internal).transform { in =>
): SCollection[(K, U)] = {
self.self.transform { in =>
val a = aggregator // defeat closure
in.mapValues(a.prepare)
new SCollectionWithHotKeyFanout(context, in.mapValues(a.prepare), hotKeyFanout)
.foldByKey(a.monoid)
.mapValues(a.present)
}
}

/** [[PairSCollectionFunctions.combineByKey]] with hot key fanout. */
def combineByKey[C: Coder](
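The hot-key variants above had the same shape of bug: self.self.context.wrap(self.self.internal) rebuilt a plain SCollection[(K, V)], so sumByKey / foldByKey ran without hotKeyFanout. Annotated restatement of the aggregateByKey(Aggregator) case (comments added here, not in the commit):

    // Before: wrapping the raw PCollection gives a plain SCollection, dropping the fan-out.
    self.self.context.wrap(self.self.internal).transform { in =>
      val a = aggregator // defeat closure
      in.mapValues(a.prepare)
        .sumByKey(a.semigroup)
        .mapValues(a.present)
    }

    // After: stay on the fan-out-aware wrapper so sumByKey honors hotKeyFanout.
    self.self.transform { in =>
      val a = aggregator // defeat closure
      new SCollectionWithHotKeyFanout(context, in.mapValues(a.prepare), hotKeyFanout)
        .sumByKey(a.semigroup)
        .mapValues(a.present)
    }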
@@ -22,6 +22,9 @@ import com.spotify.scio.util.MultiJoin
import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.runners.TransformHierarchy
import org.apache.beam.sdk.values.PCollection
import org.scalatest.Assertion

import scala.collection.mutable

object SimpleJob {
import com.spotify.scio._
@@ -34,7 +37,47 @@ object SimpleJob {
}
}

class NamedTransformTest extends PipelineSpec {
trait NamedTransformSpec extends PipelineSpec {
def assertTransformNameStartsWith(p: PCollectionWrapper[_], tfName: String): Assertion = {
val visitor = new AssertTransformNameVisitor(p.internal, tfName)
p.context.pipeline.traverseTopologically(visitor)
visitor.nodeFullName should startWith regex tfName
}

def assertGraphContainsStep(p: PCollectionWrapper[_], tfName: String): Assertion = {
val visitor = new NameAccumulatingVisitor()
p.context.pipeline.traverseTopologically(visitor)
withClue(s"All nodes: ${visitor.nodes.sorted.mkString("", "\n", "\n")}") {
visitor.nodes.flatMap(_.split('/')).toSet should contain(tfName)
}
}

def assertGraphContainsStepRegex(p: PCollectionWrapper[_], tfNameRegex: String): Assertion = {
val visitor = new NameAccumulatingVisitor()
p.context.pipeline.traverseTopologically(visitor)
val allNodes = visitor.nodes.sorted.mkString("", "\n", "\n")
withClue(s"$tfNameRegex did not match a step in any of the following nodes: $allNodes") {
visitor.nodes.flatMap(_.split('/')).toSet.filter(_.matches(tfNameRegex)) should not be empty
}
}

class AssertTransformNameVisitor(pcoll: PCollection[_], tfName: String)
extends Pipeline.PipelineVisitor.Defaults {
val prefix: List[String] = tfName.split("[(/]").toList
var nodeFullName = "<unknown>"

override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
if (node.getOutputs.containsValue(pcoll)) nodeFullName = node.getFullName
}

class NameAccumulatingVisitor extends Pipeline.PipelineVisitor.Defaults {
var nodes: mutable.ListBuffer[String] = mutable.ListBuffer.empty[String]
override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
nodes.append(node.getFullName)
}
}

class NamedTransformTest extends NamedTransformSpec {
"ScioContext" should "support custom transform name" in {
runWithContext { sc =>
val p = sc.withName("ReadInput").parallelize(Seq("a", "b", "c"))
@@ -216,27 +259,4 @@ class NamedTransformTest extends PipelineSpec {
userNamed should be("UserNamed")
}
}

private def assertTransformNameStartsWith(p: PCollectionWrapper[_], tfName: String) = {
val visitor = new AssertTransformNameVisitor(p.internal, tfName)
p.context.pipeline.traverseTopologically(visitor)
visitor.nodeFullName should startWith regex tfName
}

private class AssertTransformNameVisitor(pcoll: PCollection[_], tfName: String)
extends Pipeline.PipelineVisitor.Defaults {
val prefix: List[String] = tfName.split("[(/]").toList
var success = false
var nodeFullName = "<unknown>"

override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
if (node.getOutputs.containsValue(pcoll)) {
nodeFullName = node.getFullName
success = node.getFullName
.split("[(/]")
.toList
.take(prefix.length)
.equals(prefix)
}
}
}
@@ -17,12 +17,10 @@

package com.spotify.scio.values

import com.spotify.scio.testing.PipelineSpec
import com.twitter.algebird.{Aggregator, Semigroup}

import com.spotify.scio.coders.Coder

class SCollectionWithFanoutTest extends PipelineSpec {
class SCollectionWithFanoutTest extends NamedTransformSpec {
"SCollectionWithFanout" should "support aggregate()" in {
runWithContext { sc =>
val p = sc.parallelize(1 to 100).withFanout(10)
@@ -73,4 +71,43 @@ class SCollectionWithFanoutTest extends PipelineSpec {
sum(1 to 100: _*) should containSingleValue(5050)
}
}

private def shouldFanOut[T](fn: SCollectionWithFanout[Int] => SCollection[T]) = {
runWithContext { sc =>
val p = fn(sc.parallelize(1 to 100).withFanout(10))
assertGraphContainsStepRegex(p, "Combine\\.perKeyWithFanout\\([^)]*\\)")
}
}

it should "fan out with aggregate(zeroValue)(seqOp)" in {
shouldFanOut(_.aggregate(0.0)(_ + _, _ + _))
}

it should "fan out with aggregate(Aggregator)" in {
shouldFanOut(_.aggregate(Aggregator.max[Int]))
}

it should "fan out with aggregate(MonoidAggregator)" in {
shouldFanOut(_.aggregate(Aggregator.immutableSortedReverseTake[Int](5)))
}

it should "fan out with combine()" in {
shouldFanOut(_.combine(_.toDouble)(_ + _)(_ + _))
}

it should "fan out with fold(zeroValue)(op)" in {
shouldFanOut(_.fold(0)(_ + _))
}

it should "fan out with fold(Monoid)" in {
shouldFanOut(_.fold)
}

it should "fan out with reduce()" in {
shouldFanOut(_.reduce(_ + _))
}

it should "fan out with sum()" in {
shouldFanOut(_.sum)
}
}
@@ -17,10 +17,9 @@

package com.spotify.scio.values

import com.spotify.scio.testing.PipelineSpec
import com.twitter.algebird.Aggregator

class SCollectionWithHotKeyFanoutTest extends PipelineSpec {
class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec {
"SCollectionWithHotKeyFanout" should "support aggregateByKey()" in {
runWithContext { sc =>
val p = sc.parallelize(1 to 100).map(("a", _)) ++ sc
@@ -93,4 +92,48 @@ class SCollectionWithHotKeyFanoutTest extends PipelineSpec {
r2 should containInAnyOrder(Seq(("a", 1), ("b", 4), ("c", 5050)))
}
}

private def shouldFanOut[T](
fn: SCollectionWithHotKeyFanout[String, Int] => SCollection[T]
) = {
runWithContext { sc =>
val p1 = sc.parallelize(1 to 100).map(("a", _))
val p2 = sc.parallelize(1 to 10).map(("b", _))
val p = (p1 ++ p2).withHotKeyFanout(10)
assertGraphContainsStepRegex(fn(p), "Combine\\.perKeyWithFanout\\([^)]*\\)")
}
}

it should "fan out with aggregateByKey(zeroValue)(seqOp)" in {
shouldFanOut(_.aggregateByKey(0.0)(_ + _, _ + _))
}

it should "fan out with aggregateByKey(Aggregator)" in {
shouldFanOut(_.aggregateByKey(Aggregator.max[Int]))
}

it should "fan out with aggregateByKey(MonoidAggregator)" in {
shouldFanOut(_.aggregateByKey(Aggregator.immutableSortedReverseTake[Int](5)))
}

it should "fan out with combineByKey()" in {
shouldFanOut(_.combineByKey(_.toDouble)(_ + _)(_ + _))
}

it should "fan out with foldByKey(zeroValue)(op)" in {
shouldFanOut(_.foldByKey(0)(_ + _))
}

it should "fan out with foldByKey(Monoid)" in {
shouldFanOut(_.foldByKey)
}

it should "fan out with reduceByKey()" in {
shouldFanOut(_.reduceByKey(_ + _))
}

it should "fan out with sumByKey()" in {
shouldFanOut(_.sumByKey)
}

}
