diff --git a/scio-core/src/main/scala/com/spotify/scio/values/PairSCollectionFunctions.scala b/scio-core/src/main/scala/com/spotify/scio/values/PairSCollectionFunctions.scala index 4e944a919b..4ac15a1003 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/PairSCollectionFunctions.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/PairSCollectionFunctions.scala @@ -67,7 +67,7 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { )(f: KV[K, UI] => (K, UO)): SCollection[(K, UO)] = { self.transform( _.withName("TupleToKv").toKV - .applyTransform(t) + .applyTransform(t.getName, t) .withName("KvToTuple") .map(f) ) diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala index a625a1c01f..eba6bdc00c 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala @@ -202,16 +202,13 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] { name: Option[String], transform: PTransform[_ >: PCollection[T], PCollection[U]] ): SCollection[U] = { - val t = - if ( - (classOf[Combine.Globally[T, U]] isAssignableFrom transform.getClass) - && ScioUtil.isWindowed(this) - ) { - // In case PCollection is windowed - transform.asInstanceOf[Combine.Globally[T, U]].withoutDefaults() - } else { - transform - } + val isCombineGlobally = classOf[Combine.Globally[T, U]].isAssignableFrom(transform.getClass) + val t = if (isCombineGlobally && ScioUtil.isWindowed(this)) { + // In case PCollection is windowed + transform.asInstanceOf[Combine.Globally[T, U]].withoutDefaults() + } else { + transform + } context.wrap(this.applyInternal(name, t)) } diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala index 23bf4b04b6..a5bf1a4b4a 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala @@ -50,13 +50,21 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In /** [[SCollection.aggregate[A,U]* SCollection.aggregate]] with fan out. */ def aggregate[A: Coder, U: Coder](aggregator: Aggregator[T, A, U]): SCollection[U] = { val a = aggregator // defeat closure - coll.transform(_.map(a.prepare).sum(a.semigroup).map(a.present)) + coll.transform { in => + new SCollectionWithFanout(in.map(a.prepare), fanout) + .sum(a.semigroup) + .map(a.present) + } } /** [[SCollection.aggregate[A,U]* SCollection.aggregate]] with fan out. */ def aggregate[A: Coder, U: Coder](aggregator: MonoidAggregator[T, A, U]): SCollection[U] = { val a = aggregator // defeat closure - coll.transform(_.map(a.prepare).fold(a.monoid).map(a.present)) + coll.transform { in => + new SCollectionWithFanout(in.map(a.prepare), fanout) + .fold(a.monoid) + .map(a.present) + } } /** [[SCollection.combine]] with fan out. */ diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala index 36242a839a..5459a857df 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala @@ -61,21 +61,19 @@ class SCollectionWithHotKeyFanout[K, V] private[values] ( */ def aggregateByKey[U: Coder]( zeroValue: U - )(seqOp: (U, V) => U, combOp: (U, U) => U): SCollection[(K, U)] = - self.applyPerKey( - withFanout(Combine.perKey(Functions.aggregateFn(context, zeroValue)(seqOp, combOp))) - )( - kvToTuple - ) + )(seqOp: (U, V) => U, combOp: (U, U) => U): SCollection[(K, U)] = { + val cmb = Combine.perKey[K, V, U](Functions.aggregateFn(context, zeroValue)(seqOp, combOp)) + self.applyPerKey(withFanout(cmb))(kvToTuple) + } /** * [[PairSCollectionFunctions.aggregateByKey[A,U]* PairSCollectionFunctions.aggregateByKey]] with * hot key fanout. */ def aggregateByKey[A: Coder, U: Coder](aggregator: Aggregator[V, A, U]): SCollection[(K, U)] = - self.self.context.wrap(self.self.internal).transform { in => + self.self.transform { in => val a = aggregator // defeat closure - in.mapValues(a.prepare) + new SCollectionWithHotKeyFanout(context, in.mapValues(a.prepare), hotKeyFanout) .sumByKey(a.semigroup) .mapValues(a.present) } @@ -86,13 +84,14 @@ class SCollectionWithHotKeyFanout[K, V] private[values] ( */ def aggregateByKey[A: Coder, U: Coder]( aggregator: MonoidAggregator[V, A, U] - ): SCollection[(K, U)] = - self.self.context.wrap(self.self.internal).transform { in => + ): SCollection[(K, U)] = { + self.self.transform { in => val a = aggregator // defeat closure - in.mapValues(a.prepare) + new SCollectionWithHotKeyFanout(context, in.mapValues(a.prepare), hotKeyFanout) .foldByKey(a.monoid) .mapValues(a.present) } + } /** [[PairSCollectionFunctions.combineByKey]] with hot key fanout. */ def combineByKey[C: Coder]( diff --git a/scio-test/src/test/scala/com/spotify/scio/values/NamedTransformTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/NamedTransformTest.scala index 46911fe562..69483f12c4 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/NamedTransformTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/NamedTransformTest.scala @@ -22,6 +22,9 @@ import com.spotify.scio.util.MultiJoin import org.apache.beam.sdk.Pipeline import org.apache.beam.sdk.runners.TransformHierarchy import org.apache.beam.sdk.values.PCollection +import org.scalatest.Assertion + +import scala.collection.mutable object SimpleJob { import com.spotify.scio._ @@ -34,7 +37,47 @@ object SimpleJob { } } -class NamedTransformTest extends PipelineSpec { +trait NamedTransformSpec extends PipelineSpec { + def assertTransformNameStartsWith(p: PCollectionWrapper[_], tfName: String): Assertion = { + val visitor = new AssertTransformNameVisitor(p.internal, tfName) + p.context.pipeline.traverseTopologically(visitor) + visitor.nodeFullName should startWith regex tfName + } + + def assertGraphContainsStep(p: PCollectionWrapper[_], tfName: String): Assertion = { + val visitor = new NameAccumulatingVisitor() + p.context.pipeline.traverseTopologically(visitor) + withClue(s"All nodes: ${visitor.nodes.sorted.mkString("", "\n", "\n")}") { + visitor.nodes.flatMap(_.split('/')).toSet should contain(tfName) + } + } + + def assertGraphContainsStepRegex(p: PCollectionWrapper[_], tfNameRegex: String): Assertion = { + val visitor = new NameAccumulatingVisitor() + p.context.pipeline.traverseTopologically(visitor) + val allNodes = visitor.nodes.sorted.mkString("", "\n", "\n") + withClue(s"$tfNameRegex did not match a step in any of the following nodes: $allNodes") { + visitor.nodes.flatMap(_.split('/')).toSet.filter(_.matches(tfNameRegex)) should not be empty + } + } + + class AssertTransformNameVisitor(pcoll: PCollection[_], tfName: String) + extends Pipeline.PipelineVisitor.Defaults { + val prefix: List[String] = tfName.split("[(/]").toList + var nodeFullName = "" + + override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit = + if (node.getOutputs.containsValue(pcoll)) nodeFullName = node.getFullName + } + + class NameAccumulatingVisitor extends Pipeline.PipelineVisitor.Defaults { + var nodes: mutable.ListBuffer[String] = mutable.ListBuffer.empty[String] + override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit = + nodes.append(node.getFullName) + } +} + +class NamedTransformTest extends NamedTransformSpec { "ScioContext" should "support custom transform name" in { runWithContext { sc => val p = sc.withName("ReadInput").parallelize(Seq("a", "b", "c")) @@ -216,27 +259,4 @@ class NamedTransformTest extends PipelineSpec { userNamed should be("UserNamed") } } - - private def assertTransformNameStartsWith(p: PCollectionWrapper[_], tfName: String) = { - val visitor = new AssertTransformNameVisitor(p.internal, tfName) - p.context.pipeline.traverseTopologically(visitor) - visitor.nodeFullName should startWith regex tfName - } - - private class AssertTransformNameVisitor(pcoll: PCollection[_], tfName: String) - extends Pipeline.PipelineVisitor.Defaults { - val prefix: List[String] = tfName.split("[(/]").toList - var success = false - var nodeFullName = "" - - override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit = - if (node.getOutputs.containsValue(pcoll)) { - nodeFullName = node.getFullName - success = node.getFullName - .split("[(/]") - .toList - .take(prefix.length) - .equals(prefix) - } - } } diff --git a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala index 890da80521..70b364c43c 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala @@ -17,12 +17,10 @@ package com.spotify.scio.values -import com.spotify.scio.testing.PipelineSpec import com.twitter.algebird.{Aggregator, Semigroup} - import com.spotify.scio.coders.Coder -class SCollectionWithFanoutTest extends PipelineSpec { +class SCollectionWithFanoutTest extends NamedTransformSpec { "SCollectionWithFanout" should "support aggregate()" in { runWithContext { sc => val p = sc.parallelize(1 to 100).withFanout(10) @@ -73,4 +71,43 @@ class SCollectionWithFanoutTest extends PipelineSpec { sum(1 to 100: _*) should containSingleValue(5050) } } + + private def shouldFanOut[T](fn: SCollectionWithFanout[Int] => SCollection[T]) = { + runWithContext { sc => + val p = fn(sc.parallelize(1 to 100).withFanout(10)) + assertGraphContainsStepRegex(p, "Combine\\.perKeyWithFanout\\([^)]*\\)") + } + } + + it should "fan out with aggregate(zeroValue)(seqOp)" in { + shouldFanOut(_.aggregate(0.0)(_ + _, _ + _)) + } + + it should "fan out with aggregate(Aggregator)" in { + shouldFanOut(_.aggregate(Aggregator.max[Int])) + } + + it should "fan out with aggregate(MonoidAggregator)" in { + shouldFanOut(_.aggregate(Aggregator.immutableSortedReverseTake[Int](5))) + } + + it should "fan out with combine()" in { + shouldFanOut(_.combine(_.toDouble)(_ + _)(_ + _)) + } + + it should "fan out with fold(zeroValue)(op)" in { + shouldFanOut(_.fold(0)(_ + _)) + } + + it should "fan out with fold(Monoid)" in { + shouldFanOut(_.fold) + } + + it should "fan out with reduce()" in { + shouldFanOut(_.reduce(_ + _)) + } + + it should "fan out with sum()" in { + shouldFanOut(_.sum) + } } diff --git a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala index b42e4da2e9..d6b8b67283 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala @@ -17,10 +17,9 @@ package com.spotify.scio.values -import com.spotify.scio.testing.PipelineSpec import com.twitter.algebird.Aggregator -class SCollectionWithHotKeyFanoutTest extends PipelineSpec { +class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec { "SCollectionWithHotKeyFanout" should "support aggregateByKey()" in { runWithContext { sc => val p = sc.parallelize(1 to 100).map(("a", _)) ++ sc @@ -93,4 +92,48 @@ class SCollectionWithHotKeyFanoutTest extends PipelineSpec { r2 should containInAnyOrder(Seq(("a", 1), ("b", 4), ("c", 5050))) } } + + private def shouldFanOut[T]( + fn: SCollectionWithHotKeyFanout[String, Int] => SCollection[T] + ) = { + runWithContext { sc => + val p1 = sc.parallelize(1 to 100).map(("a", _)) + val p2 = sc.parallelize(1 to 10).map(("b", _)) + val p = (p1 ++ p2).withHotKeyFanout(10) + assertGraphContainsStepRegex(fn(p), "Combine\\.perKeyWithFanout\\([^)]*\\)") + } + } + + it should "fan out with aggregateByKey(zeroValue)(seqOp)" in { + shouldFanOut(_.aggregateByKey(0.0)(_ + _, _ + _)) + } + + it should "fan out with aggregateByKey(Aggregator)" in { + shouldFanOut(_.aggregateByKey(Aggregator.max[Int])) + } + + it should "fan out with aggregateByKey(MonoidAggregator)" in { + shouldFanOut(_.aggregateByKey(Aggregator.immutableSortedReverseTake[Int](5))) + } + + it should "fan out with combineByKey()" in { + shouldFanOut(_.combineByKey(_.toDouble)(_ + _)(_ + _)) + } + + it should "fan out with foldByKey(zeroValue)(op)" in { + shouldFanOut(_.foldByKey(0)(_ + _)) + } + + it should "fan out with foldByKey(Monoid)" in { + shouldFanOut(_.foldByKey) + } + + it should "fan out with reduceByKey()" in { + shouldFanOut(_.reduceByKey(_ + _)) + } + + it should "fan out with sumByKey()" in { + shouldFanOut(_.sumByKey) + } + }