diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b79615b..24883be 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,7 +6,7 @@ CHANGELOG ======== * pyspark: SageMakerModel: Fix bugs in creating model from training job, s3 file and endpoint - +* spark/pyspark: XGBoostSageMakerEstimator: Fix seed hyperparameter to use correct type (Int) 1.0.4 ===== diff --git a/sagemaker-pyspark-sdk/src/sagemaker_pyspark/algorithms/XGBoostSageMakerEstimator.py b/sagemaker-pyspark-sdk/src/sagemaker_pyspark/algorithms/XGBoostSageMakerEstimator.py index ce56700..92c183f 100644 --- a/sagemaker-pyspark-sdk/src/sagemaker_pyspark/algorithms/XGBoostSageMakerEstimator.py +++ b/sagemaker-pyspark-sdk/src/sagemaker_pyspark/algorithms/XGBoostSageMakerEstimator.py @@ -331,7 +331,7 @@ class XGBoostSageMakerEstimator(SageMakerEstimatorBase): seed = Param( Params._dummy(), "seed", "Random number seed", - typeConverter=TypeConverters.toFloat) + typeConverter=TypeConverters.toInt) num_round = Param( Params._dummy(), "num_round", diff --git a/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimator.scala b/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimator.scala index 4271b65..84dd1ab 100644 --- a/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimator.scala +++ b/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimator.scala @@ -379,8 +379,8 @@ private[algorithms] trait XGBoostParams extends Params { /** Random number seed. * Default = 0 */ - val seed: DoubleParam = new DoubleParam(this, "seed", "Random number seed.") - def getSeed: Double = $(seed) + val seed: IntParam = new IntParam(this, "seed", "Random number seed.") + def getSeed: Int = $(seed) /** * Number of rounds for gradient boosting. Must be >= 1. Required. @@ -614,7 +614,7 @@ class XGBoostSageMakerEstimator( def setEvalMetric(value: String) : this.type = set(evalMetric, value) - def setSeed(value: Double) : this.type = set(seed, value) + def setSeed(value: Int) : this.type = set(seed, value) def setNumRound(value: Int) : this.type = set(numRound, value) diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala index cbfb55d..8218cb0 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala @@ -301,7 +301,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito it should "setSeed" in { - val seed = 10.0 + val seed = 10 estimator.setSeed(seed) assert(seed == estimator.getSeed) } @@ -381,7 +381,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito "objective" -> "reg:logistic", "base_score" -> "0.5", "eval_metric" -> "mae", - "seed" -> "0.0" + "seed" -> "0" ) assert(hyperParamMap.asJava == estimator.makeHyperParameters())