From 34e506ea7cf5ee88434495964aeafd3f9710f143 Mon Sep 17 00:00:00 2001 From: Istvan Bartha <106101+pityka@users.noreply.github.com> Date: Fri, 9 Feb 2024 12:15:53 +0100 Subject: [PATCH] extratree - always split on missing if missing data is present --- .../main/scala/lamp/forest/extratrees.scala | 17 +- .../src/main/scala/lamp/forest/package.scala | 258 ++++++++---------- .../scala/lamp/forest/extratree.test.scala | 24 +- 3 files changed, 143 insertions(+), 156 deletions(-) diff --git a/extratrees/src/main/scala/lamp/forest/extratrees.scala b/extratrees/src/main/scala/lamp/forest/extratrees.scala index f21b737b..7c30016d 100644 --- a/extratrees/src/main/scala/lamp/forest/extratrees.scala +++ b/extratrees/src/main/scala/lamp/forest/extratrees.scala @@ -11,15 +11,16 @@ object ClassificationLeaf { import upickle.default.{ReadWriter => RW, macroRW} implicit val rw: RW[ClassificationLeaf] = macroRW } +/** Left bin contains elements < than cutpoint XOR missing elements if cutpoint is empty + */ case class ClassificationNonLeaf( - left: ClassificationTree, + left: ClassificationTree, right: ClassificationTree, splitFeature: Int, - cutpoint: Double, - splitMissingIsLess: Boolean + cutpoint: Option[Double], ) extends ClassificationTree { override def toString = - s"ClassificationTree(left=$left,right=$right,splitFeatures=$splitFeature,cutpoint=$cutpoint,splitMissingIsLess=$splitMissingIsLess)" + s"ClassificationTree(left=$left,right=$right,splitFeatures=$splitFeature,cutpoint=$cutpoint)" } object ClassificationNonLeaf { import upickle.default.{ReadWriter => RW, macroRW} @@ -37,15 +38,17 @@ case class RegressionLeaf(targetMean: Double) extends RegressionTree { override def toString = s"RegressionLeaf(targetMean=$targetMean)" } +/** Left bin contains elements < than cutpoint XOR missing elements if cutpoint is empty + */ case class RegressionNonLeaf( left: RegressionTree, right: RegressionTree, splitFeature: Int, - cutpoint: Double, - splitMissingIsLess: Boolean + cutpoint: Option[Double], + ) extends RegressionTree { override def toString = - s"RegressionNonLeaf(left=$left,right=$right,splitFeatures=$splitFeature,cutpoint=$cutpoint,splitMissingIsLess=$splitMissingIsLess)" + s"RegressionNonLeaf(left=$left,right=$right,splitFeatures=$splitFeature,cutpoint=$cutpoint)" } object RegressionLeaf { import upickle.default.{ReadWriter => RW, macroRW} diff --git a/extratrees/src/main/scala/lamp/forest/package.scala b/extratrees/src/main/scala/lamp/forest/package.scala index 565db3a7..28353418 100644 --- a/extratrees/src/main/scala/lamp/forest/package.scala +++ b/extratrees/src/main/scala/lamp/forest/package.scala @@ -10,15 +10,15 @@ package object extratrees { private[lamp] def lessThanCutpoint( v: Vec[Double], cutpoint: Double, - missingIsLess: Boolean + takeMissing: Boolean ) = { val l = v.length val ar = new Array[Boolean](l) var i = 0 - if (missingIsLess) { + if (takeMissing) { while (i < l) { val x = v.raw(i) - ar(i) = x.isNaN || x < cutpoint + ar(i) = x.isNaN i += 1 } } else { @@ -28,7 +28,7 @@ package object extratrees { i += 1 } } - Vec(ar) + (Vec(ar)) } private[lamp] def minmax(self: Vec[Double]) = { @@ -80,7 +80,7 @@ package object extratrees { var bestScore = Double.NegativeInfinity var bestFeature = -1 var bestCutpoint = Double.NaN - var bestMissingIsLess = false + var bestSplitOnMissing = false var visited = 0 while (N - high < k && high - low > 0) { val r = rng.nextInt(low, high - 1) @@ -91,44 +91,26 @@ package object extratrees { low += 1 } else { val col = takeCol(data, subset, attr) - val cutpoint = { + val cutpoint = if (hasMissing) Double.NaN else { def score(candidateCutPoint: Int): Double = { val cutpoint = col.raw(candidateCutPoint) - val takeMissingIsLess = - if (!hasMissing) null - else - lessThanCutpoint(col, cutpoint, true) - val takeMissingIsNotLess = + + val takeLessThanCutpoint = lessThanCutpoint(col, cutpoint, false) - val scoreMissingIsLess = - if (!hasMissing) Double.NaN - else - giniScore( - targetAtSubset, - weightsAtSubset, - takeMissingIsLess, - giniTotal, - numClasses, - buf1, - buf2 - ) + val scoreMissingIsNotLess = giniScore( targetAtSubset, weightsAtSubset, - takeMissingIsNotLess, + takeLessThanCutpoint, giniTotal, numClasses, buf1, buf2 ) - val chosenMissingIsLess = - (!scoreMissingIsLess.isNaN && (scoreMissingIsLess > scoreMissingIsNotLess || scoreMissingIsNotLess.isNaN)) - val chosenScore = - if (chosenMissingIsLess) scoreMissingIsLess - else scoreMissingIsNotLess - chosenScore + + scoreMissingIsNotLess } var i = 0 val n = col.length @@ -145,45 +127,46 @@ package object extratrees { col(maxi) } - val takeMissingIsLess = + val (takeMissing) = if (!hasMissing) null else lessThanCutpoint(col, cutpoint, true) - val takeMissingIsNotLess = - lessThanCutpoint(col, cutpoint, false) + val takeLessThanCutpoint = if (hasMissing) null + else lessThanCutpoint(col, cutpoint, false) - val scoreMissingIsLess = + val scoreSplitOnMissing = if (!hasMissing) Double.NaN else giniScore( targetAtSubset, weightsAtSubset, - takeMissingIsLess, + takeMissing, giniTotal, numClasses, buf1, buf2 ) - val scoreMissingIsNotLess = giniScore( + val scoreSplitOnCutpoint = if (hasMissing) Double.NaN else giniScore( targetAtSubset, weightsAtSubset, - takeMissingIsNotLess, + takeLessThanCutpoint, giniTotal, numClasses, buf1, buf2 ) - val chosenMissingIsLess = - (!scoreMissingIsLess.isNaN && (scoreMissingIsLess > scoreMissingIsNotLess || scoreMissingIsNotLess.isNaN)) + val chosenSplitOnMissing = + (!scoreSplitOnMissing.isNaN && (scoreSplitOnMissing > scoreSplitOnCutpoint || scoreSplitOnCutpoint.isNaN)) val chosenScore = - if (chosenMissingIsLess) scoreMissingIsLess else scoreMissingIsNotLess + if (chosenSplitOnMissing) scoreSplitOnMissing + else scoreSplitOnCutpoint if (chosenScore > bestScore) { bestScore = chosenScore bestFeature = attr bestCutpoint = cutpoint - bestMissingIsLess = chosenMissingIsLess + bestSplitOnMissing = chosenSplitOnMissing } if (chosenScore.isNaN) { swap(r, low) @@ -196,9 +179,9 @@ package object extratrees { } } if (visited == 0 || bestCutpoint.isNaN) - (-1, bestCutpoint, low, bestMissingIsLess) + (-1, if (bestSplitOnMissing) None else Some(bestCutpoint), low) else - (bestFeature, bestCutpoint, low, bestMissingIsLess) + (bestFeature, if (bestSplitOnMissing) None else Some(bestCutpoint), low) } private[lamp] def splitClassification( data: Mat[Double], @@ -227,7 +210,7 @@ package object extratrees { var bestScore = Double.NegativeInfinity var bestFeature = -1 var bestCutpoint = Double.NaN - var bestMissingIsLess = false + var bestSplitOnMissing = false var visited = 0 while (N - high < k && high - low > 0) { val r = rng.nextInt(low, high - 1) @@ -240,45 +223,46 @@ package object extratrees { val cutpoint = rng.nextDouble(min, max) val col = takeCol(data, subset, attr) - val takeMissingIsLess = + val (takeMissing) = if (!hasMissing) null else lessThanCutpoint(col, cutpoint, true) - val takeMissingIsNotLess = + val takeLessThanCutpoint = lessThanCutpoint(col, cutpoint, false) - val scoreMissingIsLess = + val scoreSplitOnMissing = if (!hasMissing) Double.NaN else giniScore( targetAtSubset, weightsAtSubset, - takeMissingIsLess, + takeMissing, giniTotal, numClasses, buf1, buf2 ) - val scoreMissingIsNotLess = giniScore( + val scoreSplitOnCutpoint = if (hasMissing) Double.NaN else giniScore( targetAtSubset, weightsAtSubset, - takeMissingIsNotLess, + takeLessThanCutpoint, giniTotal, numClasses, buf1, buf2 ) - val chosenMissingIsLess = - (!scoreMissingIsLess.isNaN && (scoreMissingIsLess > scoreMissingIsNotLess || scoreMissingIsNotLess.isNaN)) + val chosenSplitOnMissing = + (!scoreSplitOnMissing.isNaN && (scoreSplitOnMissing > scoreSplitOnCutpoint || scoreSplitOnCutpoint.isNaN)) val chosenScore = - if (chosenMissingIsLess) scoreMissingIsLess else scoreMissingIsNotLess + if (chosenSplitOnMissing) scoreSplitOnMissing + else scoreSplitOnCutpoint if (chosenScore > bestScore) { bestScore = chosenScore bestFeature = attr bestCutpoint = cutpoint - bestMissingIsLess = chosenMissingIsLess + bestSplitOnMissing = chosenSplitOnMissing } if (chosenScore.isNaN) { swap(r, low) @@ -291,9 +275,9 @@ package object extratrees { } } if (visited == 0 || bestCutpoint.isNaN) - (-1, bestCutpoint, low, bestMissingIsLess) + (-1, None, low) else - (bestFeature, bestCutpoint, low, bestMissingIsLess) + (bestFeature, if (bestSplitOnMissing) None else Some(bestCutpoint), low) } private[lamp] def splitBestRegression( data: Mat[Double], @@ -333,33 +317,18 @@ package object extratrees { val cutpoint = { def score(candidateCutPoint: Int): Double = { val cutpoint = col.raw(candidateCutPoint) - val takeMissingIsLess = - if (!hasMissing) null - else - lessThanCutpoint(col, cutpoint, true) + val takeMissingIsNotLess = lessThanCutpoint(col, cutpoint, false) - val scoreMissingIsLess = - if (!hasMissing) Double.NaN - else - computeVarianceReduction( - targetAtSubset, - takeMissingIsLess, - varianceNoSplit - ) - val scoreMissingIsNotLess = computeVarianceReduction( + + computeVarianceReduction( targetAtSubset, takeMissingIsNotLess, varianceNoSplit ) - val chosenMissingIsLess = - (!scoreMissingIsLess.isNaN && (scoreMissingIsLess > scoreMissingIsNotLess || scoreMissingIsNotLess.isNaN)) - val chosenScore = - if (chosenMissingIsLess) scoreMissingIsLess - else scoreMissingIsNotLess - chosenScore + } var i = 0 val n = col.length @@ -390,7 +359,7 @@ package object extratrees { takeMissingIsLess, varianceNoSplit ) - val scoreMissingIsNotLess = computeVarianceReduction( + val scoreMissingIsNotLess =if (hasMissing) Double.NaN else computeVarianceReduction( targetAtSubset, takeMissingIsNotLess, varianceNoSplit @@ -419,9 +388,9 @@ package object extratrees { } } if (visited == 0 || bestCutpoint.isNaN) - (-1, bestCutpoint, low, bestMissingIsLess) + (-1, if (bestMissingIsLess) None else Some(bestCutpoint), low) else - (bestFeature, bestCutpoint, low, bestMissingIsLess) + (bestFeature, if (bestMissingIsLess) None else Some(bestCutpoint), low) } private[lamp] def splitRegression( @@ -460,7 +429,7 @@ package object extratrees { } else { val cutpoint = rng.nextDouble(min, max) val col = takeCol(data, subset, attr) - val takeMissingIsLess = + val (takeMissingIsLess) = if (!hasMissing) null else lessThanCutpoint(col, cutpoint, true) @@ -469,13 +438,13 @@ package object extratrees { val scoreMissingIsLess = if (!hasMissing) Double.NaN - else + else computeVarianceReduction( targetAtSubset, takeMissingIsLess, varianceNoSplit ) - val scoreMissingIsNotLess = computeVarianceReduction( + val scoreMissingIsNotLess = if (hasMissing) Double.NaN else computeVarianceReduction( targetAtSubset, takeMissingIsNotLess, varianceNoSplit @@ -504,9 +473,9 @@ package object extratrees { } } if (visited == 0 || bestCutpoint.isNaN) - (-1, bestCutpoint, low, bestMissingIsLess) + (-1, if (bestMissingIsLess) None else Some(bestCutpoint), low) else - (bestFeature, bestCutpoint, low, bestMissingIsLess) + (bestFeature, if (bestMissingIsLess) None else Some(bestCutpoint), low) } @@ -520,13 +489,21 @@ package object extratrees { left, right, splitFeature, - cutpoint, - missingIsLess + Some(cutpoint) + ) => + if (sample.raw(splitFeature) < cutpoint) traverse(left) + else traverse(right) + + case ClassificationNonLeaf( + left, + right, + splitFeature, + None ) => if ( - sample.raw(splitFeature) < cutpoint || (missingIsLess && sample + sample .raw(splitFeature) - .isNaN) + .isNaN ) traverse(left) else traverse(right) } @@ -560,13 +537,20 @@ package object extratrees { left, right, splitFeature, - cutpoint, - missingIsLess + Some(cutpoint) + ) => + if (sample.raw(splitFeature) < cutpoint) traverse(left) + else traverse(right) + case RegressionNonLeaf( + left, + right, + splitFeature, + None ) => if ( - sample.raw(splitFeature) < cutpoint || (missingIsLess && sample + sample .raw(splitFeature) - .isNaN) + .isNaN ) traverse(left) else traverse(right) } @@ -785,15 +769,13 @@ package object extratrees { leftTree: RegressionTree, rightTree: RegressionTree, splitFeatureIdx: Int, - splitCutpoint: Double, - splitMissingIsLess: Boolean + splitCutpoint: Option[Double] ) = RegressionNonLeaf( leftTree, rightTree, splitFeatureIdx, - splitCutpoint, - splitMissingIsLess + splitCutpoint ) def targetIsConstant = { @@ -814,7 +796,7 @@ package object extratrees { else if (targetIsConstant) makeLeaf else { - val (splitFeatureIdx, splitCutpoint, nConstant2, missingIsLess) = + val (splitFeatureIdx, splitCutpoint, nConstant2) = if (bestSplit) splitBestRegression( data, @@ -841,19 +823,20 @@ package object extratrees { val splitFeature = col(data, splitFeatureIdx) val leftSubset = - if (missingIsLess) - subset.filter(s => - splitFeature.raw(s) < splitCutpoint || splitFeature.raw(s).isNaN - ) - else subset.filter(s => splitFeature.raw(s) < splitCutpoint) + if (splitCutpoint.isEmpty) + subset.filter(s => splitFeature.raw(s).isNaN) + else { + val c = splitCutpoint.get + subset.filter(s => splitFeature.raw(s) < c) + } val rightSubset = - if (missingIsLess) - subset.filter(s => splitFeature.raw(s) >= splitCutpoint) - else - subset.filter(s => - splitFeature.raw(s) >= splitCutpoint || splitFeature.raw(s).isNaN - ) + if (splitCutpoint.isEmpty) + subset.filter(s => !splitFeature.raw(s).isNaN) + else { + val c = splitCutpoint.get + subset.filter(s => splitFeature.raw(s) >= c) + } val leftTree = buildTreeRegression( @@ -887,8 +870,7 @@ package object extratrees { leftTree, rightTree, splitFeatureIdx, - splitCutpoint, - missingIsLess + splitCutpoint ) } } @@ -966,15 +948,13 @@ package object extratrees { leftTree: ClassificationTree, rightTree: ClassificationTree, splitFeatureIdx: Int, - splitCutpoint: Double, - splitMissingIsLess: Boolean + splitCutpoint: Option[Double] ) = ClassificationNonLeaf( leftTree, rightTree, splitFeatureIdx, - splitCutpoint, - splitMissingIsLess + splitCutpoint ) def targetIsConstant = { val col = targetInSubset @@ -994,7 +974,7 @@ package object extratrees { else if (targetIsConstant) makeLeaf else { - val (splitFeatureIdx, splitCutpoint, numConstant2, missingIsLess) = + val (splitFeatureIdx, splitCutpoint, numConstant2) = if (bestSplit) splitBestClassification( data, @@ -1024,19 +1004,20 @@ package object extratrees { val splitFeature = col(data, splitFeatureIdx) val leftSubset = - if (missingIsLess) - subset.filter(s => - splitFeature.raw(s) < splitCutpoint || splitFeature.raw(s).isNaN - ) - else subset.filter(s => splitFeature.raw(s) < splitCutpoint) + if (splitCutpoint.isEmpty) + subset.filter(s => splitFeature.raw(s).isNaN) + else { + val c = splitCutpoint.get + subset.filter(s => splitFeature.raw(s) < c) + } val rightSubset = - if (missingIsLess) - subset.filter(s => splitFeature.raw(s) >= splitCutpoint) - else - subset.filter(s => - splitFeature.raw(s) >= splitCutpoint || splitFeature.raw(s).isNaN - ) + if (splitCutpoint.isEmpty) + subset.filter(s => !splitFeature.raw(s).isNaN) + else { + val c = splitCutpoint.get + subset.filter(s => splitFeature.raw(s) >= c) + } val leftTree = buildTreeClassification( @@ -1074,8 +1055,7 @@ package object extratrees { leftTree, rightTree, splitFeatureIdx, - splitCutpoint, - missingIsLess + splitCutpoint ) } } @@ -1091,8 +1071,9 @@ package object extratrees { val bufF = Buffer.empty[T](m) while (i < n) { val v: T = vec.raw(i) - if (pred(i)) bufT.+=(v) + if (pred(i) ) bufT.+=(v) else bufF.+=(v) + i += 1 } (Vec(bufT.toArray), Vec(bufF.toArray)) @@ -1108,7 +1089,8 @@ package object extratrees { buf2: Array[Double] ) = { val numSamplesNoSplit = - if (sampleWeights.isEmpty) samplesInSplit.length.toDouble + if (sampleWeights.isEmpty) + samplesInSplit.length.toDouble else sampleWeights.get.sum2 var i = 0 var targetInCount = 0.0 @@ -1120,13 +1102,13 @@ package object extratrees { if (sampleWeights.isEmpty) { while (i < n) { val v: Int = target.raw(i) - if (samplesInSplit.raw(i)) { + if ( samplesInSplit.raw(i)) { targetInCount += 1 distributionIn(v) += 1d - } else { + } else { targetOutCount += 1 distributionOut(v) += 1d - } + } i += 1 } } else { @@ -1134,13 +1116,13 @@ package object extratrees { while (i < n) { val v: Int = target.raw(i) val ww = weights.raw(i) - if (samplesInSplit.raw(i)) { + if ( samplesInSplit.raw(i)) { targetInCount += ww distributionIn(v) += ww - } else { + } else { targetOutCount += ww distributionOut(v) += ww - } + } i += 1 } } @@ -1209,9 +1191,9 @@ package object extratrees { else targetOutSplit.sampleVariance * (targetOutSplit.length - 1d) / (targetOutSplit.length) - val numSamplesNoSplit = - target.length.toDouble - + val numSamplesNoSplit = + target.length.toDouble + (varianceNoSplit - (targetInSplit.length.toDouble / numSamplesNoSplit.toDouble) * varianceInSplit - (targetOutSplit.length.toDouble / numSamplesNoSplit.toDouble) * varianceOutSplit) / varianceNoSplit diff --git a/extratrees/src/test/scala/lamp/forest/extratree.test.scala b/extratrees/src/test/scala/lamp/forest/extratree.test.scala index b2289b0c..9d21658e 100644 --- a/extratrees/src/test/scala/lamp/forest/extratree.test.scala +++ b/extratrees/src/test/scala/lamp/forest/extratree.test.scala @@ -13,6 +13,7 @@ class ExtraTreesSuite extends AnyFunSuite { ) assert(r == 0.999999495454448) } + test("gini impurity") { val t = Vec(1, 1, 0, 0) val gt = giniImpurity(t, None, 2) @@ -28,6 +29,7 @@ class ExtraTreesSuite extends AnyFunSuite { ) == 0.5 ) } + test("gini impurity weighted") { val t = Vec(1, 1, 0, 0) val gt = giniImpurity(t, Some(vec.ones(4)), 2) @@ -84,7 +86,7 @@ class ExtraTreesSuite extends AnyFunSuite { targetAtSubset = Vec(0d, 0.1d, 100d, 100.1, 100.2), rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) - assert(r == ((0, 3.424021023861243, 0, false))) + assert(r == ((0, Some(3.424021023861243), 0))) } test("splitBestRegression 1") { val attr = Array(0, 1) @@ -97,7 +99,7 @@ class ExtraTreesSuite extends AnyFunSuite { targetAtSubset = Vec(0d, 0.1d, 0.1, 100.1, 100.2), rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) - assert(r == ((0, 4d, 0, false))) + assert(r == ((0, Some(4d), 0))) } test("splitClassification 1") { val attr = Array(0, 1) @@ -113,7 +115,7 @@ class ExtraTreesSuite extends AnyFunSuite { rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) assert(attr.toVector == Vector(1, 0)) - assert(r == ((0, 3.424021023861243, 0, false))) + assert(r == ((0, Some(3.424021023861243), 0))) } test("splitBestClassification") { val attr = Array(0, 1) @@ -129,7 +131,7 @@ class ExtraTreesSuite extends AnyFunSuite { rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) assert(attr.toVector == Vector(1, 0)) - assert(r == ((0, 4d, 0, false))) + assert(r == ((0, Some(4d), 0))) } test("splitClassification 1 weighted") { val attr = Array(0, 1) @@ -145,7 +147,7 @@ class ExtraTreesSuite extends AnyFunSuite { rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) assert(attr.toVector == Vector(1, 0)) - assert(r == ((0, 3.424021023861243, 0, false))) + assert(r == ((0, Some(3.424021023861243), 0))) } test("splitClassification 1 0-weighted") { val attr = Array(0, 1) @@ -176,7 +178,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) - assert(r == ((1, 97.54668482609304, 1, false))) + assert(r == ((1, Some(97.54668482609304), 1))) assert(attr.toVector == Vector(0, 1)) } test("splitClassification 3") { @@ -196,7 +198,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) - assert(r == ((2, 97.54668482609304, 2, false))) + assert(r == ((2, Some(97.54668482609304), 2))) assert(attr.toVector == Vector(0, 1, 2)) } test("splitClassification 4") { @@ -216,7 +218,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(0L) ) - assert(r == ((2, 97.54668482609304, 1, false))) + assert(r == ((2, Some(97.54668482609304), 1))) assert(attr.toVector == Vector(1, 0, 2)) } test("splitClassification 5") { @@ -236,7 +238,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(1L) ) - assert(r == ((2, 97.84900936098786, 1, false))) + assert(r == ((2, Some(97.84900936098786), 1))) assert(attr.toVector == Vector(0, 1, 2)) } test("splitClassification 6") { @@ -256,7 +258,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(123L) ) - assert(r == ((2, 96.07259095141863, 2, false))) + assert(r == ((2, Some(96.07259095141863), 2))) assert(attr.toVector == Vector(0, 1, 2)) } test("splitClassification 7") { @@ -276,7 +278,7 @@ class ExtraTreesSuite extends AnyFunSuite { weightsAtSubset = None, rng = org.saddle.spire.random.rng.Cmwc5.fromTime(123L) ) - assert(r == ((2, 96.07259095141863, 2, false))) + assert(r == ((2, Some(96.07259095141863), 2))) assert(attr.toVector == Vector(1, 0, 2)) }