Skip to content

Commit

Permalink
[SPARK-26966][ML] Update to JPMML 1.4.8
Browse files Browse the repository at this point in the history
JPMML apparently only supports Java 9 in 1.4.2+. We are seeing test failures from JPMML relating to JAXB when running on Java 11. It's shaded and not a big change, so it should be safe.

Existing tests.

Closes apache#23868 from srowen/SPARK-26966.

Authored-by: Sean Owen <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
  • Loading branch information
srowen authored and sumwale committed Jun 10, 2022
1 parent b896a6e commit d2ab68a
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 24 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ allprojects {
jlineVersion = '2.14.6'
xbeanAsm5Version = '4.5'
breezeVersion = '0.13.2'
pmmlVersion = '1.2.17'
pmmlVersion = '1.4.15'
classutilVersion = '1.4.0'
scoptVersion = '3.7.1'
mesosVersion = '1.0.4'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._
import org.dmg.pmml.{DataDictionary, DataField, DataType, FieldName, MiningField,
MiningFunction, MiningSchema, OpType}
import org.dmg.pmml.regression.{NumericPredictor, RegressionModel, RegressionTable}

import org.apache.spark.mllib.regression.GeneralizedLinearModel

Expand All @@ -29,7 +31,7 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
private[mllib] class BinaryClassificationPMMLModelExport(
model: GeneralizedLinearModel,
description: String,
normalizationMethod: RegressionNormalizationMethodType,
normalizationMethod: RegressionModel.NormalizationMethod,
threshold: Double)
extends PMMLModelExport {

Expand All @@ -47,7 +49,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1")
var interceptNO = threshold
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
if (RegressionModel.NormalizationMethod.LOGIT == normalizationMethod) {
if (threshold <= 0) {
interceptNO = Double.MinValue
} else if (threshold >= 1) {
Expand All @@ -58,7 +60,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
}
val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0")
val regressionModel = new RegressionModel()
.setFunctionName(MiningFunctionType.CLASSIFICATION)
.setMiningFunction(MiningFunction.CLASSIFICATION)
.setMiningSchema(miningSchema)
.setModelName(description)
.setNormalizationMethod(normalizationMethod)
Expand All @@ -69,7 +71,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
.setUsageType(FieldUsageType.ACTIVE))
.setUsageType(MiningField.UsageType.ACTIVE))
regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

Expand All @@ -79,7 +81,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
.addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
miningSchema
.addMiningFields(new MiningField(targetField)
.setUsageType(FieldUsageType.TARGET))
.setUsageType(MiningField.UsageType.TARGET))

dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._
import org.dmg.pmml.{DataDictionary, DataField, DataType, FieldName, MiningField,
MiningFunction, MiningSchema, OpType}
import org.dmg.pmml.regression.{NumericPredictor, RegressionModel, RegressionTable}

import org.apache.spark.mllib.regression.GeneralizedLinearModel

Expand All @@ -45,7 +47,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel()
.setFunctionName(MiningFunctionType.REGRESSION)
.setMiningFunction(MiningFunction.REGRESSION)
.setMiningSchema(miningSchema)
.setModelName(description)
.addRegressionTables(regressionTable)
Expand All @@ -55,7 +57,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
.setUsageType(FieldUsageType.ACTIVE))
.setUsageType(MiningField.UsageType.ACTIVE))
regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

Expand All @@ -64,7 +66,7 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(targetField)
.setUsageType(FieldUsageType.TARGET))
.setUsageType(MiningField.UsageType.TARGET))

dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._
import org.dmg.pmml.{Array, CompareFunction, ComparisonMeasure, DataDictionary, DataField, DataType,
FieldName, MiningField, MiningFunction, MiningSchema, OpType, SquaredEuclidean}
import org.dmg.pmml.clustering.{Cluster, ClusteringField, ClusteringModel}

import org.apache.spark.mllib.clustering.KMeansModel

Expand Down Expand Up @@ -48,7 +50,7 @@ private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModel
.setModelName("k-means")
.setMiningSchema(miningSchema)
.setComparisonMeasure(comparisonMeasure)
.setFunctionName(MiningFunctionType.CLUSTERING)
.setMiningFunction(MiningFunction.CLUSTERING)
.setModelClass(ClusteringModel.ModelClass.CENTER_BASED)
.setNumberOfClusters(model.clusterCenters.length)

Expand All @@ -57,9 +59,9 @@ private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModel
dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.addMiningFields(new MiningField(fields(i))
.setUsageType(FieldUsageType.ACTIVE))
.setUsageType(MiningField.UsageType.ACTIVE))
clusteringModel.addClusteringFields(
new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF))
new ClusteringField(fields(i)).setCompareFunction(CompareFunction.ABS_DIFF))
}

dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.RegressionNormalizationMethodType
import org.dmg.pmml.regression.RegressionModel

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
Expand All @@ -44,12 +44,12 @@ private[mllib] object PMMLModelExportFactory {
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
case svm: SVMModel =>
new BinaryClassificationPMMLModelExport(
svm, "linear SVM", RegressionNormalizationMethodType.NONE,
svm, "linear SVM", RegressionModel.NormalizationMethod.NONE,
svm.getThreshold.getOrElse(0.0))
case logistic: LogisticRegressionModel =>
if (logistic.numClasses == 2) {
new BinaryClassificationPMMLModelExport(
logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT,
logistic, "logistic regression", RegressionModel.NormalizationMethod.LOGIT,
logistic.getThreshold.getOrElse(0.5))
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType
import org.dmg.pmml.regression.RegressionModel

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionModel
Expand Down Expand Up @@ -51,7 +50,8 @@ class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
// ensure logistic regression has normalization method set to LOGIT
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
assert(pmmlRegressionModel.getNormalizationMethod() ===
RegressionModel.NormalizationMethod.LOGIT)
}

test("linear SVM PMML export") {
Expand All @@ -78,7 +78,8 @@ class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
// ensure linear SVM has normalization method set to NONE
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
assert(pmmlRegressionModel.getNormalizationMethod() ===
RegressionModel.NormalizationMethod.NONE)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.RegressionModel
import org.dmg.pmml.regression.RegressionModel

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.ClusteringModel
import org.dmg.pmml.clustering.ClusteringModel

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
Expand Down
12 changes: 12 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@
<version>14.0.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
<version>1.4.8</version>
<scope>provided</scope>
<exclusions>
<exclusion>
<groupId>org.jpmml</groupId>
<artifactId>pmml-agent</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- End of shaded deps -->
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down

0 comments on commit d2ab68a

Please sign in to comment.