Skip to content

Commit

Permalink
Support storing precision of decimal types in Schema class (#17176)
Browse files Browse the repository at this point in the history
In Spark, the `DecimalType` has a specific number of digits to represent the numbers. However, when creating a data Schema, only type and name of the column are stored, thus we lose that precision information. As such, it would be difficult to reconstruct the original decimal types from cudf's `Schema` instance.

This PR adds a `precision` member variable to the `Schema` class in cudf Java, allowing it to store the precision number of the original decimal column.

Partially contributes to NVIDIA/spark-rapids#11560.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #17176
  • Loading branch information
ttnghia authored Oct 29, 2024
1 parent 3775f7b commit ddfb284
Showing 1 changed file with 70 additions and 7 deletions.
77 changes: 70 additions & 7 deletions java/src/main/java/ai/rapids/cudf/Schema.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,52 @@ public class Schema {
public static final Schema INFERRED = new Schema();

private final DType topLevelType;

/**
* Default precision value, used when no precision is specified or the column type is not decimal.
*/
private static final int UNKNOWN_PRECISION = -1;

/**
* Store precision for the top level column, only applicable if the column is a decimal type.
* <p/>
* This variable is not designed to be used by any of libcudf's APIs since libcudf does not support
* precisions for fixed point numbers.
* Instead, it is used only to pass down the precision values from Spark's DecimalType to the
* JNI level, where some JNI functions require these values to perform their operations.
*/
private final int topLevelPrecision;

private final List<String> childNames;
private final List<Schema> childSchemas;
private boolean flattened = false;
private String[] flattenedNames;
private DType[] flattenedTypes;
private int[] flattenedPrecisions;
private int[] flattenedCounts;

/**
 * Constructs a schema node with an explicit decimal precision.
 *
 * @param topLevelType type of the top level column
 * @param topLevelPrecision decimal precision of the top level column; UNKNOWN_PRECISION when
 *                          the column is not a decimal type or the precision is not known
 * @param childNames names of the child columns, if any
 * @param childSchemas schemas of the child columns, matching {@code childNames} by index
 */
private Schema(DType topLevelType,
int topLevelPrecision,
List<String> childNames,
List<Schema> childSchemas) {
this.topLevelType = topLevelType;
this.topLevelPrecision = topLevelPrecision;
this.childNames = childNames;
this.childSchemas = childSchemas;
}

/**
 * Constructs a schema node without a decimal precision; the precision defaults to
 * UNKNOWN_PRECISION.
 *
 * @param topLevelType type of the top level column
 * @param childNames names of the child columns, if any
 * @param childSchemas schemas of the child columns, matching {@code childNames} by index
 */
private Schema(DType topLevelType,
List<String> childNames,
List<Schema> childSchemas) {
this(topLevelType, UNKNOWN_PRECISION, childNames, childSchemas);
}

/**
 * Inferred schema: all members are left unset, signalling that the schema should be
 * inferred from the data itself (see {@link #INFERRED}).
 */
private Schema() {
topLevelType = null;
topLevelPrecision = UNKNOWN_PRECISION;
childNames = null;
childSchemas = null;
}
Expand Down Expand Up @@ -104,14 +130,17 @@ private void flattenIfNeeded() {
if (flatLen == 0) {
flattenedNames = null;
flattenedTypes = null;
flattenedPrecisions = null;
flattenedCounts = null;
} else {
String[] names = new String[flatLen];
DType[] types = new DType[flatLen];
int[] precisions = new int[flatLen];
int[] counts = new int[flatLen];
collectFlattened(names, types, counts, 0);
collectFlattened(names, types, precisions, counts, 0);
flattenedNames = names;
flattenedTypes = types;
flattenedPrecisions = precisions;
flattenedCounts = counts;
}
flattened = true;
Expand All @@ -128,19 +157,20 @@ private int flattenedLength(int startingLength) {
return startingLength;
}

private int collectFlattened(String[] names, DType[] types, int[] counts, int offset) {
private int collectFlattened(String[] names, DType[] types, int[] precisions, int[] counts, int offset) {
if (childSchemas != null) {
for (int i = 0; i < childSchemas.size(); i++) {
Schema child = childSchemas.get(i);
names[offset] = childNames.get(i);
types[offset] = child.topLevelType;
precisions[offset] = child.topLevelPrecision;
if (child.childNames != null) {
counts[offset] = child.childNames.size();
} else {
counts[offset] = 0;
}
offset++;
offset = this.childSchemas.get(i).collectFlattened(names, types, counts, offset);
offset = this.childSchemas.get(i).collectFlattened(names, types, precisions, counts, offset);
}
}
return offset;
Expand Down Expand Up @@ -226,6 +256,22 @@ public int[] getFlattenedTypeScales() {
return ret;
}

/**
 * Get decimal precisions of the columns' types flattened from all levels in schema by
 * depth-first traversal.
 * <p/>
 * This is used to pass down the decimal precisions from Spark only to the JNI layer, where
 * some JNI functions require precision values to perform their operations.
 * Decimal precisions should not be consumed by any of libcudf's APIs since libcudf does not
 * support precisions for fixed point numbers.
 *
 * @return An array containing the decimal precision of each column in the schema, with -1
 *         (unknown) for columns whose precision was never specified, or null when the schema
 *         has no columns.
 */
public int[] getFlattenedDecimalPrecisions() {
  flattenIfNeeded();
  // Return a defensive copy so callers cannot mutate the cached flattened state; this is
  // consistent with the other flattened getters, which build and return a fresh array.
  return flattenedPrecisions == null ? null : flattenedPrecisions.clone();
}

/**
* Get the types of the columns in schema flattened from all levels by depth-first traversal.
* @return An array containing types of all columns in schema.
Expand Down Expand Up @@ -307,11 +353,13 @@ public HostColumnVector.DataType asHostDataType() {

public static class Builder {
private final DType topLevelType;
private final int topLevelPrecision;
private final List<String> names;
private final List<Builder> types;

private Builder(DType topLevelType) {
private Builder(DType topLevelType, int topLevelPrecision) {
this.topLevelType = topLevelType;
this.topLevelPrecision = topLevelPrecision;
if (topLevelType == DType.STRUCT || topLevelType == DType.LIST) {
// There can be children
names = new ArrayList<>();
Expand All @@ -322,14 +370,19 @@ private Builder(DType topLevelType) {
}
}

/**
 * Creates a builder node without a decimal precision; the precision defaults to
 * UNKNOWN_PRECISION.
 *
 * @param topLevelType type of the top level column
 */
private Builder(DType topLevelType) {
this(topLevelType, UNKNOWN_PRECISION);
}

/**
* Add a new column
* @param type the type of column to add
* @param name the name of the column to add (Ignored for list types)
* @param precision the decimal precision, only applicable for decimal types
* @return the builder for the new column. This should really only be used when the type
* passed in is a LIST or a STRUCT.
*/
public Builder addColumn(DType type, String name) {
public Builder addColumn(DType type, String name, int precision) {
if (names == null) {
throw new IllegalStateException("A column of type " + topLevelType +
" cannot have children");
Expand All @@ -340,21 +393,31 @@ public Builder addColumn(DType type, String name) {
if (names.contains(name)) {
throw new IllegalStateException("Cannot add duplicate names to a schema");
}
Builder ret = new Builder(type);
Builder ret = new Builder(type, precision);
types.add(ret);
names.add(name);
return ret;
}

/**
 * Add a new column without a decimal precision (the precision defaults to
 * UNKNOWN_PRECISION).
 *
 * @param type the type of column to add
 * @param name the name of the column to add (Ignored for list types)
 * @return the builder for the new column. This should really only be used when the type
 *         passed in is a LIST or a STRUCT.
 */
public Builder addColumn(DType type, String name) {
return addColumn(type, name, UNKNOWN_PRECISION);
}

/**
 * Adds a single column to the current schema and returns this builder for chaining.
 * {@code addColumn} is preferred as it can be used to support nested types.
 *
 * @param type the type of the column.
 * @param name the name of the column.
 * @param precision the decimal precision, only applicable for decimal types.
 * @return this for chaining.
 */
public Builder column(DType type, String name, int precision) {
addColumn(type, name, precision);
return this;
}

public Builder column(DType type, String name) {
addColumn(type, name);
addColumn(type, name, UNKNOWN_PRECISION);
return this;
}

Expand Down

0 comments on commit ddfb284

Please sign in to comment.