Skip to content

Commit

Permalink
Introduction of timeout in passes (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
khemrajrathore authored Jul 30, 2024
1 parent de6269f commit 41b751e
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 2 deletions.
141 changes: 139 additions & 2 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package io.shiftleft.passes

import com.google.protobuf.GeneratedMessageV3
import io.shiftleft.SerializedCpg
import io.shiftleft.codepropertygraph.generated.Cpg
import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder}
import io.shiftleft.utils.StatsLogger
import org.slf4j.{Logger, LoggerFactory, MDC}
import overflowdb.BatchedUpdate

import java.util.concurrent.{TimeUnit, TimeoutException}
import java.util.function.{BiConsumer, Supplier}
import scala.annotation.nowarn
import scala.concurrent.duration.DurationLong
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.{Duration, DurationLong}
import scala.util.{Failure, Success, Try}

/* CpgPass
Expand Down Expand Up @@ -55,6 +57,56 @@ abstract class CpgPass(cpg: Cpg, outName: String = "", keyPool: Option[KeyPool]
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
* passes eagerly, and release them only when the entire chain has run.
* */
/** Pass that builds one diff through `runWithBuilder` and applies it to `cpg.graph` in a single batched update.
  *
  * @param cpg
  *   graph the resulting diff is applied to
  * @param outName
  *   not used by this class (kept for signature compatibility, hence `@nowarn`)
  * @param keyPool
  *   optional key pool handed to `BatchedUpdate.applyDiff`; `null` is passed when absent
  * @param timeout
  *   forwarded unchanged to [[NewStyleCpgPassBaseWithTimeout]]; defaults to -1
  */
abstract class ForkJoinParallelCpgPassWithTimeout[T <: AnyRef](
  cpg: Cpg,
  @nowarn outName: String = "",
  keyPool: Option[KeyPool] = None,
  timeout: Long = -1
) extends NewStyleCpgPassBaseWithTimeout[T](timeout) {

  /** Runs the pass, applies the collected diff to the graph and logs timing statistics.
    *
    * `serializedCpg` is only inspected to decide whether the final log line mentions serialization; `prefix` is not
    * used in this method. Any `Exception` raised while building or applying the diff is logged and rethrown.
    * `finish()` is invoked here in addition to the call made by `runWithBuilder`.
    */
  override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = {
    baseLogger.info(s"Start of pass: $name")
    StatsLogger.initiateNewStage(getClass.getSimpleName, Some(name), getClass.getSuperclass.getSimpleName)

    val nanosStart      = System.nanoTime()
    var nParts          = 0
    var buildFinishedAt = -1L // stays -1 when the build phase did not complete
    var nDiff           = -1
    var nDiffT          = -1

    try {
      val builder = Cpg.newDiffGraphBuilder
      nParts = runWithBuilder(builder)
      buildFinishedAt = System.nanoTime()
      nDiff = builder.size()

      val applied = overflowdb.BatchedUpdate.applyDiff(cpg.graph, builder, keyPool.orNull, null)
      nDiffT = applied.transitiveModifications()
    } catch {
      case failure: Exception =>
        baseLogger.error(s"Pass ${name} failed", failure)
        throw failure
    } finally {
      try {
        finish()
      } finally {
        // Nested on purpose: finish() must be part of the measured time, and the closing log line and stage
        // bookkeeping have to happen even when finish() itself throws.
        val nanosStop = System.nanoTime()
        val fracRun = buildFinishedAt match {
          case -1L     => 0.0
          case builtAt => (nanosStop - builtAt) * 100.0 / (nanosStop - nanosStart + 1)
        }
        val hasSerializedTarget = serializedCpg != null && !serializedCpg.isEmpty
        val serializationString = if (hasSerializedTarget) " Diff serialized and stored." else ""
        baseLogger.info(
          f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s"
        )
        StatsLogger.endLastStage()
      }
    }
  }

}

abstract class ForkJoinParallelCpgPass[T <: AnyRef](
cpg: Cpg,
@nowarn outName: String = "",
Expand Down Expand Up @@ -168,6 +220,91 @@ abstract class NewStyleCpgPassBase[T <: AnyRef] extends CpgPassBase {
}
}

/** Base class for passes that split their work into independent parts and can stop waiting for those parts after a
  * time budget.
  *
  * @param timeout
  *   budget in seconds for the parallel execution of all parts; `-1` disables the timeout. The budget is only
  *   enforced when there is more than one part and `isParallel` is true: a single part and the sequential
  *   (`isParallel == false`) path always run to completion.
  * @tparam T
  *   type every element returned by `generateParts` is cast to before being handed to `runOnPart`
  */
abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extends CpgPassBase {
  type DiffGraphBuilder = overflowdb.BatchedUpdate.DiffGraphBuilder

  // generate Array of parts that can be processed in parallel
  def generateParts(): Array[? <: AnyRef]

  // setup large data structures, acquire external resources
  def init(): Unit = {}

  // release large data structures and external resources
  def finish(): Unit = {}

  // main function: add desired changes to builder
  def runOnPart(builder: DiffGraphBuilder, part: T): Unit

  // Override this to disable parallelism of passes. Useful for debugging.
  def isParallel: Boolean = true

  override def createAndApply(): Unit = createApplySerializeAndStore(null)

  /** Calls `init()`, runs `runOnPart` for every generated part and absorbs the collected changes into
    * `externalBuilder`; `finish()` is always called before returning or throwing.
    *
    * @return
    *   the number of generated parts, including parts whose changes were dropped because of the timeout
    */
  override def runWithBuilder(externalBuilder: BatchedUpdate.DiffGraphBuilder): Int = {
    try {
      init()
      val parts  = generateParts()
      val nParts = parts.length
      nParts match {
        case 0 =>
        case 1 =>
          runOnPart(externalBuilder, parts(0).asInstanceOf[T])
        case _ if !isParallel =>
          // Collect into a scratch builder first so that a failing part leaves externalBuilder untouched.
          val diff = Cpg.newDiffGraphBuilder
          parts.foreach(part => runOnPart(diff, part.asInstanceOf[T]))
          externalBuilder.absorb(diff)
        case _ =>
          externalBuilder.absorb(runPartsWithTimeout(parts))
      }
      nParts
    } finally {
      finish()
    }
  }

  /** Runs every part in its own `Future` on the global execution context and merges the results that become
    * available before the deadline.
    *
    * Parts that have not started when the deadline passes are skipped. Parts that are already running cannot be
    * cancelled: they keep running in the background and their result is discarded. A failure of a part other than
    * a `TimeoutException` is rethrown.
    */
  private def runPartsWithTimeout(parts: Array[? <: AnyRef]): DiffGraphBuilder = {
    implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global

    // Deadline is based on System.nanoTime, so wall-clock adjustments cannot shorten or extend the budget.
    // None means "wait for every part".
    val deadline =
      if (timeout == -1) None
      else Some(scala.concurrent.duration.Deadline.now + Duration(timeout, TimeUnit.SECONDS))
    val diffGraphAccumulator = Cpg.newDiffGraphBuilder

    val futures = parts.map { part =>
      Future {
        // Do not start work whose result would be discarded anyway.
        if (deadline.exists(_.isOverdue())) None
        else {
          val diffGraphBuilder = Cpg.newDiffGraphBuilder
          runOnPart(diffGraphBuilder, part.asInstanceOf[T])
          Some(diffGraphBuilder)
        }
      }
    }

    var droppedParts = 0
    futures.foreach { future =>
      // A non-positive remaining time still returns the value of an already completed future.
      val remaining = deadline.fold[Duration](Duration.Inf)(_.timeLeft)
      Try(Await.result(future, remaining)) match {
        case Success(Some(diffGraphBuilder)) =>
          diffGraphAccumulator.absorb(diffGraphBuilder)
        case Success(None) | Failure(_: TimeoutException) =>
          baseLogger.debug(s"Timeout occurred for passed timeout value of ${timeout} seconds")
          droppedParts += 1
        case Failure(e) => throw e
      }
    }

    if (droppedParts > 0) {
      baseLogger.warn(
        s"Pass $name: $droppedParts of ${parts.length} parts did not finish within $timeout seconds, their changes are dropped"
      )
    }
    diffGraphAccumulator
  }
}

/** Companion holding the logger created for the `CpgPassBase` class. */
object CpgPassBase {
// Private to this companion; presumably exposed to passes through an accessor on CpgPassBase (not visible here).
private val baseLogger: Logger = LoggerFactory.getLogger(classOf[CpgPassBase])
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,37 @@ class ParallelCpgPassNewTests extends AnyWordSpec with Matchers {
}

}

/** Exercises [[ForkJoinParallelCpgPassWithTimeout]] with 100 parts that each sleep for one second, once with a
  * two-second timeout and once without any timeout.
  */
class ForkJoinParallelCpgPassNewTests extends AnyWordSpec with Matchers {

  private object Fixture {

    /** Creates an empty CPG and a pass that adds one `NewFile` node per part, then hands both to `f`. */
    def apply(keyPools: Option[Iterator[KeyPool]] = None, timeout: Long = -1)(f: (Cpg, CpgPassBase) => Unit): Unit = {
      val emptyCpg     = Cpg.empty
      val maybeKeyPool = keyPools.flatMap(_.nextOption())

      class MyPass(cpg: Cpg)
          extends ForkJoinParallelCpgPassWithTimeout[String](cpg, "MyPass", maybeKeyPool, timeout = timeout) {

        // Parts are the strings "1" to "100".
        override def generateParts(): Array[String] = Array.tabulate(100)(index => (index + 1).toString)

        override def runOnPart(diffGraph: DiffGraphBuilder, part: String): Unit = {
          // Slow part on purpose, so that 100 parts cannot all finish within the two-second timeout.
          Thread.sleep(1000)
          diffGraph.addNode(NewFile().name(part))
        }
      }

      f(emptyCpg, new MyPass(emptyCpg))
    }
  }

  // Name properties of all nodes currently stored in the graph.
  private def storedNames(cpg: Cpg) = cpg.graph.nodes.map(_.property(Properties.Name)).toList

  "ForkJoinParallelPassWithTimeout" should {
    "generate partial result in case of timeout" in Fixture(timeout = 2) { (cpg, pass) =>
      pass.createAndApply()
      assert(storedNames(cpg).size != 100)
    }

    "generate complete result without timeout" in Fixture() { (cpg, pass) =>
      pass.createAndApply()
      assert(storedNames(cpg).size == 100)
    }
  }

}

0 comments on commit 41b751e

Please sign in to comment.