diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 39af92baa06c7..4122c71a22073 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -99,6 +99,15 @@ object CommandPythonFunctionType extends PythonFunctionType{ override val value
 object SqlUdfPythonFunctionType extends PythonFunctionType{ override val value = 1 }
 object PandasUdfPythonFunctionType extends PythonFunctionType{ override val value = 2 }
 
+/**
+ * Interface that can be used when building an iterator to read data from Python
+ */
+private[spark] trait PythonReadInterface {
+  def getDataStream: DataInputStream
+  def readLengthFromPython(): Int
+  def readFooter(): Unit
+}
+
 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
  *
@@ -123,10 +132,11 @@ private[spark] class PythonRunner(
   // TODO: support accumulator in multiple UDF
   private val accumulator = funcs.head.funcs.head.accumulator
 
-  def compute(
-      inputIterator: Iterator[_],
+  def process[U](
+      dataWriteBlock: DataOutputStream => Unit,
+      dataReadBuilder: PythonReadInterface => Iterator[U],
       partitionIndex: Int,
-      context: TaskContext): Iterator[Array[Byte]] = {
+      context: TaskContext): Iterator[U] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -139,7 +149,7 @@ private[spark] class PythonRunner(
     @volatile var released = false
 
     // Start a thread to feed the process input from our parent's iterator
-    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
+    val writerThread = new WriterThread(env, worker, dataWriteBlock, partitionIndex, context)
 
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
@@ -156,79 +166,29 @@ private[spark] class PythonRunner(
     writerThread.start()
     new MonitorThread(env, worker, context).start()
 
-    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
-    val stdoutIterator = new Iterator[Array[Byte]] {
-      override def next(): Array[Byte] = {
-        val obj = _nextObj
-        if (hasNext) {
-          _nextObj = read()
-        }
-        obj
-      }
+    // Create stream to read data from process's stdout
+    val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+    val stdoutIterator = new Iterator[U] with PythonReadInterface {
+
+      // Create iterator for reading data blocks
+      val _dataIterator = dataReadBuilder(this.asInstanceOf[PythonReadInterface])
 
-      private def read(): Array[Byte] = {
+      def safeRead[T](block: => T): T = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get
         }
+
         try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.empty[Byte]
-            case SpecialLengths.TIMING_DATA =>
-              // Timing data from worker
-              val bootTime = stream.readLong()
-              val initTime = stream.readLong()
-              val finishTime = stream.readLong()
-              val boot = bootTime - startTime
-              val init = initTime - bootTime
-              val finish = finishTime - initTime
-              val total = finishTime - startTime
-              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-                init, finish))
-              val memoryBytesSpilled = stream.readLong()
-              val diskBytesSpilled = stream.readLong()
-              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              // Signals that an exception has been thrown in python
-              val exLength = stream.readInt()
-              val obj = new Array[Byte](exLength)
-              stream.readFully(obj)
-              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
-                writerThread.exception.getOrElse(null))
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              // We've finished the data section of the output, but we can still
-              // read some accumulator updates:
-              val numAccumulatorUpdates = stream.readInt()
-              (1 to numAccumulatorUpdates).foreach { _ =>
-                val updateLen = stream.readInt()
-                val update = new Array[Byte](updateLen)
-                stream.readFully(update)
-                accumulator.add(update)
-              }
-              // Check whether the worker is ready to be re-used.
-              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-                if (reuse_worker) {
-                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-                  released = true
-                }
-              }
-              null
-          }
+          block
         } catch {
-
           case e: Exception if context.isInterrupted =>
             logDebug("Exception thrown after task interruption", e)
            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
 
           case e: Exception if env.isStopped =>
             logDebug("Exception thrown after context is stopped", e)
-            null // exit silently
+            throw new RuntimeException("TODO: exit silently") // exit silently
 
           case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
@@ -240,13 +200,112 @@ private[spark] class PythonRunner(
         }
       }
 
-      var _nextObj = read()
+      override def next(): U = {
+        safeRead {
+          _dataIterator.next()
+        }
+      }
+
+      override def hasNext: Boolean = {
+        safeRead {
+          _dataIterator.hasNext
+        }
+      }
+
+      override def getDataStream: DataInputStream = dataIn
+
+      override def readLengthFromPython(): Int = {
+        val length = dataIn.readInt()
+        if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+          // Signals that an exception has been thrown in python
+          val exLength = dataIn.readInt()
+          val obj = new Array[Byte](exLength)
+          dataIn.readFully(obj)
+          throw new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.orNull)
+        }
+        length
+      }
+
+      override def readFooter(): Unit = {
+        // Timing data from worker
+        // readLengthFromPython() // == SpecialLengths.TIMING_DATA
+        val bootTime = dataIn.readLong()
+        val initTime = dataIn.readLong()
+        val finishTime = dataIn.readLong()
+        val boot = bootTime - startTime
+        val init = initTime - bootTime
+        val finish = finishTime - initTime
+        val total = finishTime - startTime
+        logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
+          init, finish))
+        val memoryBytesSpilled = dataIn.readLong()
+        val diskBytesSpilled = dataIn.readLong()
+        context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+        context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+
+        // We've finished the data section of the output, but we can still
+        // read some accumulator updates:
+        readLengthFromPython() // == SpecialLengths.END_OF_DATA_SECTION
+        val numAccumulatorUpdates = readLengthFromPython()
+        (1 to numAccumulatorUpdates).foreach { _ =>
+          val updateLen = dataIn.readInt()
+          val update = new Array[Byte](updateLen)
+          dataIn.readFully(update)
+          accumulator.add(update)
+        }
 
-      override def hasNext: Boolean = _nextObj != null
+        // Check whether the worker is ready to be re-used.
+        if (readLengthFromPython() == SpecialLengths.END_OF_STREAM) {
+          if (reuse_worker) {
+            env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+            released = true
+          }
+        }
+        // else == SpecialLengths.END_OF_DATA_SECTION to not reuse worker
+      }
     }
+
     new InterruptibleIterator(context, stdoutIterator)
   }
 
+  def compute(
+      inputIterator: Iterator[_],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[Array[Byte]] = {
+
+    val dataWriteBlock = (out: DataOutputStream) => {
+      PythonRDD.writeIteratorToStream(inputIterator, out)
+    }
+
+    val dataReadBuilder = (in: PythonReadInterface) => {
+      new Iterator[Array[Byte]] {
+        var _lastLength: Int = _
+
+        override def hasNext: Boolean = {
+          _lastLength = in.readLengthFromPython()
+          val result = _lastLength >= 0
+          if (!result) {
+            in.readFooter()
+          }
+          result
+        }
+
+        override def next(): Array[Byte] = {
+          _lastLength match {
+            case l if l > 0 =>
+              val obj = new Array[Byte](_lastLength)
+              in.getDataStream.readFully(obj)
+              obj
+            case 0 =>
+              Array.empty[Byte]
+          }
+        }
+      }
+    }
+
+    process(dataWriteBlock, dataReadBuilder, partitionIndex, context)
+  }
+
   /**
    * The thread responsible for writing the data from the PythonRDD's parent iterator to the
    * Python process.
@@ -254,7 +313,7 @@ private[spark] class PythonRunner(
   class WriterThread(
       env: SparkEnv,
       worker: Socket,
-      inputIterator: Iterator[_],
+      dataWriteBlock: DataOutputStream => Unit,
       partitionIndex: Int,
       context: TaskContext)
     extends Thread(s"stdout writer for $pythonExec") {
@@ -340,7 +399,7 @@ private[spark] class PythonRunner(
         }
 
         // Data values
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+        dataWriteBlock(dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
@@ -701,6 +760,13 @@ private[spark] object PythonRDD extends Logging {
    * The thread will terminate after all the data are sent or any exceptions happen.
    */
   def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+    serveToStream(threadName) { out =>
+      writeIteratorToStream(items, out)
+    }
+  }
+
+  // TODO: scaladoc
+  def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Int = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)
@@ -712,13 +778,13 @@ private[spark] object PythonRDD extends Logging {
           val sock = serverSocket.accept()
           val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
           Utils.tryWithSafeFinally {
-            writeIteratorToStream(items, out)
+            dataWriteBlock(out)
           } {
             out.close()
           }
         } catch {
           case NonFatal(e) =>
-            logError(s"Error while sending iterator", e)
+            logError(s"Error while writing to stream", e)
         } finally {
           serverSocket.close()
         }
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 709c138e9b894..b3e8acf3c7a3b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -206,6 +206,56 @@ def __repr__(self):
         return "ArrowSerializer"
 
 
+class ArrowStreamSerializer(Serializer):
+
+    def __init__(self, load_to_single_batch=True):
+        self._load_to_single = load_to_single_batch
+
+    def dump_stream(self, iterator, stream):
+        import pyarrow as pa
+        write_int(1, stream)  # signal start of data block
+        writer = None
+        for batch in iterator:
+            if writer is None:
+                writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+            writer.write_batch(batch)
+        if writer is not None:
+            writer.close()
+
+    def load_stream(self, stream):
+        import pyarrow as pa
+        reader = pa.RecordBatchStreamReader(stream)
+        if self._load_to_single:
+            return reader.read_all()
+        else:
+            return iter(reader)
+
+    def __repr__(self):
+        return "ArrowStreamSerializer"
+
+
+class ArrowPandasSerializer(ArrowStreamSerializer):
+
+    def __init__(self):
+        super(ArrowPandasSerializer, self).__init__(load_to_single_batch=True)
+
+    # dumps a Pandas Series to stream
+    def dump_stream(self, iterator, stream):
+        import pyarrow as pa
+        # TODO: iterator could be a tuple
+        arr = pa.Array.from_pandas(iterator)
+        batch = pa.RecordBatch.from_arrays([arr], ["_0"])
+        super(ArrowPandasSerializer, self).dump_stream([batch], stream)
+
+    # loads stream to a list of Pandas Series
+    def load_stream(self, stream):
+        table = super(ArrowPandasSerializer, self).load_stream(stream)
+        return [c.to_pandas() for c in table.itercolumns()]
+
+    def __repr__(self):
+        return "ArrowPandasSerializer"
+
+
 class BatchedSerializer(Serializer):
 
     """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ac609d9438017..f885c79bacc70 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer, \
-    ArrowSerializer
+    ArrowSerializer, ArrowPandasSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -118,8 +118,14 @@ def read_udfs(pickleSer, infile):
     mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
     mapper = eval(mapper_str, udfs)
 
-    func = lambda _, it: map(mapper, it)
-    ser = BatchedSerializer(PickleSerializer(), 100)
+    # These lines enable UDF evaluation with Arrow
+    ser = ArrowPandasSerializer()
+    func = lambda _, series_list: mapper(series_list)  # TODO: what if not vectorizable
+
+    # Uncomment for default UDF evaluation
+    # func = lambda _, it: map(mapper, it)
+    # ser = BatchedSerializer(PickleSerializer(), 100)
+
     # profiling is not supported for UDF
     return func, None, ser, ser
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index dee2671ed7cad..95eedf90098fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -262,6 +262,72 @@ private[sql] object ArrowConverters {
       reader.close()
     }
   }
+
+  private[arrow] def writeRowsAsArrow(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      out: DataOutputStream): Unit = {
+    val allocator = new RootAllocator(Long.MaxValue)
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val loader = new VectorLoader(root)
+    val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
+
+    val batch = internalRowIterToArrowBatch(rowIter, schema, allocator)
+
+    // TODO: catch exceptions
+    loader.load(batch)
+    writer.writeBatch()
+    writer.end()
+
+    batch.close()
+    root.close()
+    allocator.close()
+  }
+
+  private[arrow] def readArrowAsRows(in: DataInputStream): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      val _allocator = new RootAllocator(Long.MaxValue)
+      private val _reader = new ArrowStreamReader(Channels.newChannel(in), _allocator)
+      private val _root = _reader.getVectorSchemaRoot
+      private var _index = 0
+      val mutableRow = new GenericInternalRow(1)
+
+      _reader.loadNextBatch()
+
+      override def hasNext: Boolean = _index < _root.getRowCount
+
+      override def next(): InternalRow = {
+        val fieldVecs = _root.getFieldVectors
+
+        if (fieldVecs.size() == 1) {
+          mutableRow(0) = fieldVecs.get(0).getAccessor.getObject(_index)
+          _index += 1
+          if (_index >= _root.getRowCount) {
+            _index = 0
+            _reader.loadNextBatch()
+          }
+          mutableRow
+        } else {
+          val fields = _root.getFieldVectors.asScala
+
+          val genericRowData = fields.map { field =>
+            val obj: Any = field.getAccessor.getObject(_index)
+            obj
+          }.toArray
+
+          _index += 1
+
+          if (_index >= _root.getRowCount) {
+            _index = 0
+            _reader.loadNextBatch()
+          }
+
+          new GenericInternalRow(genericRowData)
+        }
+      }
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala
new file mode 100644
index 0000000000000..5ac134fdd1d10
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala
@@ -0,0 +1,143 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.{DataOutputStream, File}
+
+import org.apache.spark.api.python.{ChainedPythonFunctions, PandasUdfPythonFunctionType, PythonReadInterface, PythonRunner}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.python.{HybridRowQueue, PythonUDF}
+
+//import org.apache.spark.sql.ArrowConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkEnv, TaskContext}
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]].
+ */
+case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+  extends SparkPlan {
+
+  def children: Seq[SparkPlan] = child :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+    inputRDD.mapPartitions { iter =>
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+      TaskContext.get().addTaskCompletionListener({ ctx =>
+        queue.close()
+      })
+
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+
+      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+      // For each row, add it to the queue.
+      val projectedRowIter = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        projection(inputRow)
+      }
+
+      val dataWriteBlock = (out: DataOutputStream) => {
+        ArrowConverters.writeRowsAsArrow(projectedRowIter, schema, out)
+      }
+
+      val dataReadBuilder = (in: PythonReadInterface) => {
+        new Iterator[InternalRow] {
+
+          // Check for initial error
+          in.readLengthFromPython()
+
+          val iter = ArrowConverters.readArrowAsRows(in.getDataStream)
+
+          override def hasNext: Boolean = {
+            val result = iter.hasNext
+            if (!result) {
+              in.readLengthFromPython() // == SpecialLengths.TIMING_DATA, marks end of data
+              in.readFooter()
+            }
+            result
+          }
+
+          override def next(): InternalRow = {
+            iter.next()
+          }
+        }
+      }
+
+      val context = TaskContext.get()
+
+      // Output iterator for results from Python.
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, PandasUdfPythonFunctionType, argOffsets)
+        .process(dataWriteBlock, dataReadBuilder, context.partitionId(), context)
+
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(output, output)
+
+      outputIterator.map { outputRow =>
+        resultProj(joined(queue.remove(), outputRow))
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 69b4b7bb07de6..dec46dfdc09bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
+import org.apache.spark.sql.execution.arrow.ArrowEvalPythonExec
 
 
 /**
@@ -138,7 +139,13 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
         val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
           AttributeReference(s"pythonUDF$i", u.dataType)()
         }
-        val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+        // This line enables UDF evaluation with Arrow
+        val evaluation = ArrowEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+        // Uncomment for default UDF evaluation
+        // val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
         attributeMap ++= validUdfs.zip(resultAttrs)
         evaluation
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index cd1e77f524afd..723b54e2e2a67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -163,7 +163,7 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
 * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same
 * time.
 */
-private[python] case class HybridRowQueue(
+private[execution] case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
     numFields: Int)
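
Usage sketch (reviewer note, not part of the patch): with this prototype applied, a Python UDF registered through pyspark.sql.functions.udf is evaluated vectorized. The worker calls the UDF once per partition with one pandas.Series per argument column (see read_udfs and ArrowPandasSerializer above) and expects a pandas.Series of the same length back; per the TODOs, only a single vectorizable UDF returning one column is handled so far. The SparkSession, DataFrame, and plus_one names below are illustrative only, and pyarrow plus pandas are assumed to be available on the workers.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import LongType

    spark = SparkSession.builder.getOrCreate()

    # With this change the lambda receives a pandas.Series holding the "id"
    # values for a whole partition and must return a pandas.Series of the
    # same length, instead of being called once per row.
    plus_one = udf(lambda s: s + 1, LongType())

    df = spark.range(10)
    df.select(plus_one(df["id"]).alias("id_plus_one")).show()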