Skip to content

Commit

Permalink
Introduce kudo reader
Browse files Browse the repository at this point in the history
Signed-off-by: liurenjie1024 <[email protected]>
  • Loading branch information
liurenjie1024 committed Nov 7, 2024
1 parent 2aa3348 commit 3164564
Show file tree
Hide file tree
Showing 12 changed files with 1,286 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Arms.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
/**
* This class contains utility methods for automatic resource management.
*/
class Arms {
public class Arms {
/**
* This method close the resource if an exception is thrown while executing the function.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Pair.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
/**
* A utility class for holding a pair of values.
*/
class Pair<K, V> {
public class Pair<K, V> {
private final K left;
private final V right;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni.kudo;

import java.util.OptionalLong;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;

/**
* This class is used to store the offsets of the buffer of a column in the serialized data.
*/
class ColumnOffsetInfo {
private static final long INVALID_OFFSET = -1L;
private final long validity;
private final long offset;
private final long data;
private final long dataLen;

public ColumnOffsetInfo(long validity, long offset, long data, long dataLen) {
ensure(dataLen >= 0, () -> "dataLen must be non-negative, but was " + dataLen);
this.validity = validity;
this.offset = offset;
this.data = data;
this.dataLen = dataLen;
}

public OptionalLong getValidity() {
return (validity == INVALID_OFFSET) ? OptionalLong.empty() : OptionalLong.of(validity);
}

public OptionalLong getOffset() {
return (offset == INVALID_OFFSET) ? OptionalLong.empty() : OptionalLong.of(offset);
}

public OptionalLong getData() {
return (data == INVALID_OFFSET) ? OptionalLong.empty() : OptionalLong.of(data);
}

public long getDataLen() {
return dataLen;
}

@Override
public String toString() {
return "ColumnOffsets{" +
"validity=" + validity +
", offset=" + offset +
", data=" + data +
", dataLen=" + dataLen +
'}';
}
}
74 changes: 74 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni.kudo;

import ai.rapids.cudf.ColumnView;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.DeviceMemoryBuffer;

import static java.lang.Math.toIntExact;

class ColumnViewInfo {
private final DType dtype;
private final ColumnOffsetInfo offsetInfo;
private final long nullCount;
private final long rowCount;

public ColumnViewInfo(DType dtype, ColumnOffsetInfo offsetInfo,
long nullCount, long rowCount) {
this.dtype = dtype;
this.offsetInfo = offsetInfo;
this.nullCount = nullCount;
this.rowCount = rowCount;
}

public long buildColumnView(DeviceMemoryBuffer buffer, long[] childrenView) {
long bufferAddress = buffer.getAddress();

long dataAddress = 0;
if (offsetInfo.getData().isPresent()) {
dataAddress = buffer.getAddress() + offsetInfo.getData().getAsLong();
}

long validityAddress = 0;
if (offsetInfo.getValidity().isPresent()) {
validityAddress = offsetInfo.getValidity().getAsLong() + bufferAddress;
}

long offsetsAddress = 0;
if (offsetInfo.getOffset().isPresent()) {
offsetsAddress = offsetInfo.getOffset().getAsLong() + bufferAddress;
}

return ColumnView.makeCudfColumnView(
dtype.getTypeId().getNativeId(), dtype.getScale(),
dataAddress, offsetInfo.getDataLen(),
offsetsAddress, validityAddress,
toIntExact(nullCount), toIntExact(rowCount),
childrenView);
}

@Override
public String toString() {
return "ColumnViewInfo{" +
"dtype=" + dtype +
", offsetInfo=" + offsetInfo +
", nullCount=" + nullCount +
", rowCount=" + rowCount +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni.kudo;

import ai.rapids.cudf.*;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.schema.Visitors;

import java.util.List;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static java.util.Objects.requireNonNull;

/**
* The result of merging several kudo tables into one contiguous table on the host.
*/
public class KudoHostMergeResult implements AutoCloseable {
private final Schema schema;
private final List<ColumnViewInfo> columnInfoList;
private final HostMemoryBuffer hostBuf;

KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, List<ColumnViewInfo> columnInfoList) {
requireNonNull(schema, "schema is null");
requireNonNull(columnInfoList, "columnOffsets is null");
ensure(schema.getFlattenedColumnNames().length == columnInfoList.size(), () ->
"Column offsets size does not match flattened schema size, column offsets size: " + columnInfoList.size() +
", flattened schema size: " + schema.getFlattenedColumnNames().length);
this.schema = schema;
this.columnInfoList = columnInfoList;
this.hostBuf = requireNonNull(hostBuf, "hostBuf is null") ;
}

@Override
public void close() throws Exception {
if (hostBuf != null) {
hostBuf.close();
}
}

public ContiguousTable toContiguousTable() {
return Arms.closeIfException(DeviceMemoryBuffer.allocate(hostBuf.getLength()),
deviceMemBuf -> {
if (hostBuf.getLength() > 0) {
deviceMemBuf.copyFromHostBuffer(hostBuf);
}

TableBuilder builder = new TableBuilder(columnInfoList, deviceMemBuf);
Table t = Visitors.visitSchema(schema, builder);

return new ContiguousTable(t, deviceMemBuf);
});
}

@Override
public String toString() {
return "HostMergeResult{" +
"columnOffsets=" + columnInfoList +
", hostBuf length =" + hostBuf.getLength() +
'}';
}
}
110 changes: 106 additions & 4 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
package com.nvidia.spark.rapids.jni.kudo;

import ai.rapids.cudf.*;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.Pair;
import com.nvidia.spark.rapids.jni.schema.Visitors;

import java.io.*;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
Expand Down Expand Up @@ -151,10 +158,12 @@ public class KudoSerializer {
Arrays.fill(PADDING, (byte) 0);
}

private final Schema schema;
private final int flattenedColumnCount;

public KudoSerializer(Schema schema) {
requireNonNull(schema, "schema is null");
this.schema = schema;
this.flattenedColumnCount = schema.getFlattenedColumnNames().length;
}

Expand Down Expand Up @@ -237,6 +246,81 @@ public static long writeRowCountToStream(OutputStream out, int numRows) {
}
}

/**
* Read a kudo table from an input stream.
* @param in input stream
* @return the kudo table, or empty if the input stream is empty.
* @throws IOException if an I/O error occurs
*/
public Optional<KudoTable> readOneTable(InputStream in) throws IOException {
Objects.requireNonNull(in, "Input stream must not be null");

DataInputStream din = readerFrom(in);
return KudoTableHeader.readFrom(din).map(header -> {
if (header.getNumRows() <= 0) {
throw new IllegalArgumentException("Number of rows must be > 0, but was " + header.getNumRows());
}

// Header only
if (header.getNumColumns() == 0) {
return new KudoTable(header, null);
}

return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false), buffer -> {
try {
buffer.copyFromStream(0, din, header.getTotalDataLen());
return new KudoTable(header, buffer);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
});
}

/**
* Merge a list of kudo tables into a table on host memory.
* <br/>
* The caller should ensure that the {@link KudoSerializer} used to generate kudo tables have same schema as current
* {@link KudoSerializer}, otherwise behavior is undefined.
*
* @param kudoTables list of kudo tables. This method doesn't take ownership of the input tables, and caller should
* take care of closing them after calling this method.
* @return the merged table, and metrics during merge.
*/
public Pair<KudoHostMergeResult, MergeMetrics> mergeOnHost(List<KudoTable> kudoTables) {
MergeMetrics.Builder metricsBuilder = MergeMetrics.builder();

MergedInfoCalc mergedInfoCalc = withTime(() -> MergedInfoCalc.calc(schema, kudoTables),
metricsBuilder::calcHeaderTime);
KudoHostMergeResult result = withTime(() -> KudoTableMerger.merge(schema, mergedInfoCalc),
metricsBuilder::mergeIntoHostBufferTime);
return Pair.of(result, metricsBuilder.build());

}

/**
* Merge a list of kudo tables into a contiguous table.
* <br/>
* The caller should ensure that the {@link KudoSerializer} used to generate kudo tables have same schema as current
* {@link KudoSerializer}, otherwise behavior is undefined.
*
* @param kudoTables list of kudo tables. This method doesn't take ownership of the input tables, and caller should
* take care of closing them after calling this method.
* @return the merged table, and metrics during merge.
*/
public Pair<ContiguousTable, MergeMetrics> mergeToTable(List<KudoTable> kudoTables) {
Pair<KudoHostMergeResult, MergeMetrics> result = mergeOnHost(kudoTables);
MergeMetrics.Builder builder = MergeMetrics.builder(result.getRight());
try (KudoHostMergeResult children = result.getLeft()) {
ContiguousTable table = withTime(() -> children.toContiguousTable(schema),
builder::convertIntoContiguousTableTime);

return Pair.of(table, builder.build());
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
Visitors.visitColumns(columns, headerCalc);
Expand Down Expand Up @@ -285,9 +369,27 @@ static long padFor64byteAlignment(long orig) {
return ((orig + 63) / 64) * 64;
}

static int safeLongToNonNegativeInt(long value) {
ensure(value >= 0, () -> "Expected non negative value, but was " + value);
ensure(value <= Integer.MAX_VALUE, () -> "Value is too large to fit in an int");
return (int) value;
private static DataInputStream readerFrom(InputStream in) {
if (in instanceof DataInputStream) {
return (DataInputStream)in;
}
return new DataInputStream(in);
}

static <T> T withTime(Supplier<T> task, LongConsumer timeConsumer) {
long now = System.nanoTime();
T ret = task.get();
timeConsumer.accept(System.nanoTime() - now);
return ret;
}

/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
* getValidityLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(14) => 2 bytes
*/
static long getValidityLengthInBytes(long rows) {
return (rows + 7) / 8;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.safeLongToNonNegativeInt;
import static java.util.Objects.requireNonNull;

/**
Expand Down
Loading

0 comments on commit 3164564

Please sign in to comment.