Skip to content

Commit

Permalink
Fix seed parameter for XGBoost (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChoiByungWook authored Apr 23, 2018
1 parent 4dd7f0e commit ebe6733
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CHANGELOG
========

* pyspark: SageMakerModel: Fix bugs in creating model from training job, s3 file and endpoint

* spark/pyspark: XGBoostSageMakerEstimator: Fix seed hyperparameter to use correct type (Int)

1.0.4
=====
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class XGBoostSageMakerEstimator(SageMakerEstimatorBase):
seed = Param(
Params._dummy(), "seed",
"Random number seed",
typeConverter=TypeConverters.toFloat)
typeConverter=TypeConverters.toInt)

num_round = Param(
Params._dummy(), "num_round",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,8 @@ private[algorithms] trait XGBoostParams extends Params {
/** Random number seed.
* Default = 0
*/
val seed: DoubleParam = new DoubleParam(this, "seed", "Random number seed.")
def getSeed: Double = $(seed)
val seed: IntParam = new IntParam(this, "seed", "Random number seed.")
def getSeed: Int = $(seed)

/**
* Number of rounds for gradient boosting. Must be >= 1. Required.
Expand Down Expand Up @@ -614,7 +614,7 @@ class XGBoostSageMakerEstimator(

def setEvalMetric(value: String) : this.type = set(evalMetric, value)

def setSeed(value: Double) : this.type = set(seed, value)
def setSeed(value: Int) : this.type = set(seed, value)

def setNumRound(value: Int) : this.type = set(numRound, value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito


it should "setSeed" in {
val seed = 10.0
val seed = 10
estimator.setSeed(seed)
assert(seed == estimator.getSeed)
}
Expand Down Expand Up @@ -381,7 +381,7 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito
"objective" -> "reg:logistic",
"base_score" -> "0.5",
"eval_metric" -> "mae",
"seed" -> "0.0"
"seed" -> "0"
)

assert(hyperParamMap.asJava == estimator.makeHyperParameters())
Expand Down

0 comments on commit ebe6733

Please sign in to comment.