
Commit

Enable UDF evaluation with Arrow using stream format to load as Pandas Series, modified PythonRDD to support this and maintain backwards compatibility
BryanCutler committed May 22, 2017
1 parent 57a2cde commit 45db636
Showing 7 changed files with 415 additions and 77 deletions.
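
As context for the diff below, here is a minimal, hypothetical Python sketch (not part of the commit; it only assumes pyarrow and pandas are installed) of the round trip the new ArrowPandasSerializer performs: a pandas Series becomes an Arrow record batch, is written in the Arrow stream format, and is read back as a Series.

import io

import pandas as pd
import pyarrow as pa

series = pd.Series([1, 2, 3])

# Write: Series -> Arrow array -> record batch -> Arrow stream-format bytes
arr = pa.Array.from_pandas(series)
batch = pa.RecordBatch.from_arrays([arr], ["_0"])
sink = io.BytesIO()
writer = pa.RecordBatchStreamWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()

# Read: Arrow stream-format bytes -> table -> one pandas Series per column
reader = pa.RecordBatchStreamReader(io.BytesIO(sink.getvalue()))
table = reader.read_all()
columns = [column.to_pandas() for column in table.itercolumns()]
print(columns[0].tolist())  # [1, 2, 3]

The stream format matters here because it lets the JVM and the Python worker exchange record batches over a socket without knowing the total payload size up front.
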
208 changes: 137 additions & 71 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -99,6 +99,15 @@ object CommandPythonFunctionType extends PythonFunctionType{ override val value
object SqlUdfPythonFunctionType extends PythonFunctionType{ override val value = 1 }
object PandasUdfPythonFunctionType extends PythonFunctionType{ override val value = 2 }

/**
* Interface that can be used when building an iterator to read data from Python
*/
private[spark] trait PythonReadInterface {
def getDataStream: DataInputStream
def readLengthFromPython(): Int
def readFooter(): Unit
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -123,10 +132,11 @@ private[spark] class PythonRunner(
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
def process[U](
dataWriteBlock: DataOutputStream => Unit,
dataReadBuilder: PythonReadInterface => Iterator[U],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {
context: TaskContext): Iterator[U] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -139,7 +149,7 @@ private[spark] class PythonRunner(
@volatile var released = false

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
val writerThread = new WriterThread(env, worker, dataWriteBlock, partitionIndex, context)

context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
@@ -156,79 +166,29 @@ private[spark] class PythonRunner(
writerThread.start()
new MonitorThread(env, worker, context).start()

// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
override def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
_nextObj = read()
}
obj
}
// Create stream to read data from process's stdout
val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

val stdoutIterator = new Iterator[U] with PythonReadInterface {

// Create iterator for reading data blocks
val _dataIterator = dataReadBuilder(this.asInstanceOf[PythonReadInterface])

private def read(): Array[Byte] = {
def safeRead[T](block: => T): T = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}

try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case 0 => Array.empty[Byte]
case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
val numAccumulatorUpdates = stream.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
null
}
block
} catch {

case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
null // exit silently
throw new RuntimeException("TODO: exit silently")// exit silently

case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
@@ -240,21 +200,120 @@ private[spark] class PythonRunner(
}
}

var _nextObj = read()
override def next(): U = {
safeRead {
_dataIterator.next()
}
}

override def hasNext: Boolean = {
safeRead {
_dataIterator.hasNext
}
}

override def getDataStream: DataInputStream = dataIn

override def readLengthFromPython(): Int = {
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
// Signals that an exception has been thrown in python
val exLength = dataIn.readInt()
val obj = new Array[Byte](exLength)
dataIn.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.orNull)
}
length
}

override def readFooter(): Unit = {
// Timing data from worker
//readLengthFromPython() // == SpecialLengths.TIMING_DATA
val bootTime = dataIn.readLong()
val initTime = dataIn.readLong()
val finishTime = dataIn.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = dataIn.readLong()
val diskBytesSpilled = dataIn.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)

// We've finished the data section of the output, but we can still
// read some accumulator updates:
readLengthFromPython() // == SpecialLengths.END_OF_DATA_SECTION
val numAccumulatorUpdates = readLengthFromPython()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = dataIn.readInt()
val update = new Array[Byte](updateLen)
dataIn.readFully(update)
accumulator.add(update)
}

override def hasNext: Boolean = _nextObj != null
// Check whether the worker is ready to be re-used.
if (readLengthFromPython() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
// else == SpecialLengths.END_OF_DATA_SECTION to not reuse worker
}
}

new InterruptibleIterator(context, stdoutIterator)
}

def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {

val dataWriteBlock = (out: DataOutputStream) => {
PythonRDD.writeIteratorToStream(inputIterator, out)
}

val dataReadBuilder = (in: PythonReadInterface) => {
new Iterator[Array[Byte]] {
var _lastLength: Int = _

override def hasNext: Boolean = {
_lastLength = in.readLengthFromPython()
val result = _lastLength >= 0
if (!result) {
in.readFooter()
}
result
}

override def next(): Array[Byte] = {
_lastLength match {
case l if l > 0 =>
val obj = new Array[Byte](_lastLength)
in.getDataStream.readFully(obj)
obj
case 0 =>
Array.empty[Byte]
}
}
}
}

process(dataWriteBlock, dataReadBuilder, partitionIndex, context)
}

/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[_],
dataWriteBlock: DataOutputStream => Unit,
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
@@ -340,7 +399,7 @@ }
}

// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataWriteBlock(dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
Expand Down Expand Up @@ -701,6 +760,13 @@ private[spark] object PythonRDD extends Logging {
* The thread will terminate after all the data are sent or any exceptions happen.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
serveToStream(threadName) { out =>
writeIteratorToStream(items, out)
}
}

// TODO: scaladoc
def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Int = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)
@@ -712,13 +778,13 @@ private[spark] object PythonRDD extends Logging {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
dataWriteBlock(out)
} {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator", e)
logError(s"Error while writing to stream", e)
} finally {
serverSocket.close()
}
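To make the framing above easier to follow, here is a rough standalone Python sketch (illustrative only, not code from this commit) of the length-prefixed protocol that the default compute() path consumes via dataReadBuilder: a positive length prefixes a payload, zero means an empty payload, and negative values are SpecialLengths markers (the constants below mirror PySpark's convention but should be treated as assumptions here).

import io
import struct

END_OF_DATA_SECTION = -1        # assumed SpecialLengths values
PYTHON_EXCEPTION_THROWN = -2

def read_int(stream):
    # 4-byte big-endian signed int, matching DataInputStream.readInt on the JVM side
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError("stream ended while reading a length prefix")
    return struct.unpack("!i", data)[0]

def iter_frames(stream):
    # Yield byte payloads until the end-of-data marker; the footer
    # (timing data, accumulator updates, end-of-stream) would follow it.
    while True:
        length = read_int(stream)
        if length > 0:
            yield stream.read(length)
        elif length == 0:
            yield b""
        elif length == PYTHON_EXCEPTION_THROWN:
            raise RuntimeError(stream.read(read_int(stream)).decode("utf-8"))
        elif length == END_OF_DATA_SECTION:
            return

# Example: one 3-byte frame followed by the end-of-data marker
frames = list(iter_frames(io.BytesIO(struct.pack("!i", 3) + b"abc" + struct.pack("!i", -1))))
print(frames)  # [b'abc']
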
50 changes: 50 additions & 0 deletions python/pyspark/serializers.py
@@ -206,6 +206,56 @@ def __repr__(self):
return "ArrowSerializer"


class ArrowStreamSerializer(Serializer):

def __init__(self, load_to_single_batch=True):
self._load_to_single = load_to_single_batch

def dump_stream(self, iterator, stream):
import pyarrow as pa
write_int(1, stream) # signal start of data block
writer = None
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
if writer is not None:
writer.close()

def load_stream(self, stream):
import pyarrow as pa
reader = pa.RecordBatchStreamReader(stream)
if self._load_to_single:
return reader.read_all()
else:
return iter(reader)

def __repr__(self):
return "ArrowStreamSerializer"


class ArrowPandasSerializer(ArrowStreamSerializer):

def __init__(self):
super(ArrowPandasSerializer, self).__init__(load_to_single_batch=True)

# dumps a Pandas Series to stream
def dump_stream(self, iterator, stream):
import pyarrow as pa
# TODO: iterator could be a tuple
arr = pa.Array.from_pandas(iterator)
batch = pa.RecordBatch.from_arrays([arr], ["_0"])
super(ArrowPandasSerializer, self).dump_stream([batch], stream)

# loads stream to a list of Pandas Series
def load_stream(self, stream):
table = super(ArrowPandasSerializer, self).load_stream(stream)
return [c.to_pandas() for c in table.itercolumns()]

def __repr__(self):
return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

"""
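A small pyarrow-only sketch (again illustrative, not from the commit) of the two read modes that ArrowStreamSerializer's load_to_single_batch flag toggles: read_all() collapses the whole stream into one Table, while iterating the reader yields record batches one at a time.

import io

import pyarrow as pa

batches = [
    pa.RecordBatch.from_arrays([pa.array([1, 2])], ["_0"]),
    pa.RecordBatch.from_arrays([pa.array([3, 4])], ["_0"]),
]

sink = io.BytesIO()
writer = pa.RecordBatchStreamWriter(sink, batches[0].schema)
for batch in batches:
    writer.write_batch(batch)
writer.close()

# load_to_single_batch=True path: one Table holding every batch
table = pa.RecordBatchStreamReader(io.BytesIO(sink.getvalue())).read_all()
print(table.num_rows)  # 4

# load_to_single_batch=False path: an iterator of individual record batches
reader = pa.RecordBatchStreamReader(io.BytesIO(sink.getvalue()))
print([b.num_rows for b in reader])  # [2, 2]
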
12 changes: 9 additions & 3 deletions python/pyspark/worker.py
@@ -31,7 +31,7 @@
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer, \
ArrowSerializer
ArrowSerializer, ArrowPandasSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -118,8 +118,14 @@ def read_udfs(pickleSer, infile):
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = BatchedSerializer(PickleSerializer(), 100)
# These lines enable UDF evaluation with Arrow
ser = ArrowPandasSerializer()
func = lambda _, series_list: mapper(series_list) # TODO: what if not vectorizable

# Uncomment out for default UDF evaluation
#func = lambda _, it: map(mapper, it)
#ser = BatchedSerializer(PickleSerializer(), 100)

# profiling is not supported for UDF
return func, None, ser, ser

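For the worker change above, a short illustrative contrast (not the worker's actual code path) between the default row-at-a-time evaluation that BatchedSerializer feeds and the vectorized evaluation the Arrow path enables, where the same UDF is called once on a whole pandas Series; this only works when the UDF logic is vectorizable, which is what the TODO flags.

import pandas as pd

def plus_one(x):
    return x + 1

rows = [1, 2, 3]

# Default path: the UDF is applied to one Python object at a time
row_at_a_time = list(map(plus_one, rows))   # [2, 3, 4]

# Arrow path: the worker hands the UDF a pandas Series and calls it once,
# relying on pandas' vectorized arithmetic
vectorized = plus_one(pd.Series(rows))
print(row_at_a_time, vectorized.tolist())   # [2, 3, 4] [2, 3, 4]
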
