From 4ead83f197eb8c7326b3b60eb25d59a8368e1001 Mon Sep 17 00:00:00 2001 From: david-sun-1 <138479116+david-sun-1@users.noreply.github.com> Date: Sat, 18 Nov 2023 06:19:28 +1100 Subject: [PATCH] Add functionality for early stopping rounds. (#193) * add functionality for early stopping * remove version word * evaluation msg into a parsing function and add back evaluation to updateone * Updated the call to updateone! to pass in the watchlist so it can be used by early stopping round logic. * Added comments, additional examples, fixed issues with watchlist ordering as a Dict. * Added functionality to extract the best iteration round with examples. Included additional test case coverage. * Cleaned up some lingering test cases. * Updated doc to include early stopping example. * Added additional info on data types for watchlist * Annotated OrderedDict to be more obvious. * Included using statement for OrderedCollection * Moved log message parsing to update! instead of updateone * Updated documentation and tests. * Altered the XGBoost method definition to reflect exception states for early stopping rounds and watchlist. * Created exception if extract_metric_value could not find a match when parsing XGBoost logs. --------- Co-authored-by: Wilan Wong Co-authored-by: wilan-wong-1 <148725847+wilan-wong-1@users.noreply.github.com> --- docs/src/index.md | 41 ++++++++++++ src/booster.jl | 160 ++++++++++++++++++++++++++++++++++++++++++---- test/runtests.jl | 117 +++++++++++++++++++++++++++++++++ 3 files changed, 305 insertions(+), 13 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 3e10ce8..a284f4f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMat [`XGBoost.getlabel`](@ref). + ## Booster The [`Booster`](@ref) object holds model data. They are created with training data. Internally this is always a `DMatrix` but arguments will be automatically converted. @@ -182,3 +183,43 @@ is equivalent to bst = xgboost((X, y), num_round=10) update!(bst, (X, y), num_round=10) ``` + +### Early Stopping +To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the +boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. + +If there is more than one element in watchlist, by default the last element will be used. In this case, you must use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary otherwise an exception will be generated. There will be +a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with +more than 1 element. + +Similarly, if there is more than one element in eval_metric, by default the last element will be used. + +For example: + +```julia +using LinearAlgebra +using OrderedCollections + +𝒻(x) = 2norm(x)^2 - norm(x) + +X = randn(100,3) +y = 𝒻.(eachrow(X)) + +dtrain = DMatrix((X, y)) + +X_valid = randn(50,3) +y_valid = 𝒻.(eachrow(X_valid)) + +dvalid = DMatrix((X_valid, y_valid)) + +bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse", watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]), early_stopping_rounds = 5, max_depth=6, η=0.3) + +# get the best iteration and use it for prediction +ŷ = predict(bst, X_valid, ntree_limit = bst.best_iteration) + +using Statistics +println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).") + +# we can also retain / use the best score (based on eval_metric) which is stored in the booster +println("Best RMSE from model training $(round((bst.best_score), digits = 8)).") +``` \ No newline at end of file diff --git a/src/booster.jl b/src/booster.jl index bebb054..32f94bb 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -1,4 +1,3 @@ - """ Booster @@ -50,11 +49,17 @@ mutable struct Booster # out what the hell is happening, it's never used for program logic params::Dict{Symbol,Any} - function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict()) + # store early stopping information + best_iteration::Union{Int64, Missing} + best_score::Union{Float64, Missing} + + function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing, + best_score::Union{Float64, Missing}=missing) finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params)) end end + """ setparam!(b::Booster, name, val) @@ -366,7 +371,6 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) - isempty(watchlist) || logeval(b, watchlist, round_number) _maybe_update_feature_names!(b, Xy, update_feature_names) b end @@ -382,7 +386,6 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - isempty(watchlist) || logeval(b, watchlist, round_number) _maybe_update_feature_names!(b, Xy, update_feature_names) b end @@ -422,14 +425,105 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`. The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided for custom loss. """ -function update!(b::Booster, data, a...; num_round::Integer=1, kw...) +function update!(b::Booster, data, a...; + num_round::Integer=1, + watchlist::Any = Dict("train" => data), + early_stopping_rounds::Integer=0, + maximize=false, + kw..., + ) + + if !isempty(watchlist) && early_stopping_rounds > 0 + @info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") + best_round = 0 + best_score = maximize ? -Inf : Inf + end + for j ∈ 1:num_round round_number = getnrounds(b) + 1 - updateone!(b, data, a...; round_number, kw...) + + updateone!(b, data, a...; round_number, watchlist, kw...) + + # Evaluate if watchlist is not empty + if !isempty(watchlist) + msg = evaliter(b, watchlist, round_number) + @info msg + if early_stopping_rounds > 0 + score, dataset, metric = extract_metric_value(msg) + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." + ) + # add additional fields to record the best iteration + b.best_iteration = best_round + b.best_score = best_score + return b + end + end + end end b end + + +""" + extract_metric_value(msg, dataset=nothing, metric=nothing) + +Extracts a numeric value from a message based on the specified dataset and metric. +If dataset or metric is not provided, the function will automatically find the last +mentioned dataset or metric in the message. + +# Arguments +- `msg::AbstractString`: The message containing the numeric values. +- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`). +- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`). + +# Returns +- Returns the parsed Float64 value if a match is found, otherwise returns `nothing`. + +# Examples +```julia +msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874" + +# Without specifying dataset and metric +value_without_params = extract_metric_value(msg) +println(value_without_params) # Output: (0.09311178520817574, "eval", "rmsle") + +# With specifying dataset and metric +value_with_params = extract_metric_value(msg, "train", "rmsle") +println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle") +""" + +function extract_metric_value(msg, dataset=nothing, metric=nothing) + if isnothing(dataset) + # Find the last mentioned dataset - whilst retaining order + datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)]) + dataset = last(collect(datasets)) + end + + if isnothing(metric) + # Find the last mentioned metric - whilst retaining order + metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) + metric = last(collect(metrics)) + end + + pattern = Regex("$dataset-$metric:([\\d.]+)") + + match_result = match(pattern, msg) + + if match_result != nothing + parsed_value = parse(Float64, match_result.captures[1]) + return parsed_value, dataset, metric + end + + # there was no match result - should error out + error("No match found for pattern: $dataset-$metric in message: $msg") +end + """ xgboost(data; num_round=10, watchlist=Dict(), kw...) xgboost(data, ℓ′, ℓ″; kw...) @@ -439,7 +533,19 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an followed by [`update!`](@ref) for `nrounds`. `watchlist` is a dict the keys of which are strings giving the name of the data to watch -and the values of which are [`DMatrix`](@ref) objects containing the data. +and the values of which are [`DMatrix`](@ref) objects containing the data. It is mandatory to use an OrderedDict +when utilising early_stopping_rounds and there is more than 1 element in watchlist to ensure XGBoost uses the +correct and intended dataset to perform early stop. + +`early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at +least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset +to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the +last metric in `eval_metric` (if more than one). Note that `watchlist` cannot be empty if +`early_stopping_rounds` is enabled. + +`maximize` If early_stopping_rounds is set, then this parameter must be set as well. +When it is false, it means the smaller the evaluation score the better. When set to true, +the larger the evaluation score the better. All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for @@ -450,23 +556,51 @@ See [`updateone!`](@ref) for more details. ## Examples ```julia +# Example 1: Basic usage of XGBoost (X, y) = (randn(100,3), randn(100)) -b = xgboost((X, y), 10, max_depth=10, η=0.1) +b = xgboost((X, y), num_round=10, max_depth=10, η=0.1) ŷ = predict(b, X) + +# Example 2: Using early stopping (using a validation set) with a watchlist +dtrain = DMatrix((randn(100,3), randn(100))) +dvalid = DMatrix((randn(100,3), randn(100))) + +watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid]) + +b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1) + +# note that ntree_limit in the predict function helps assign the upper bound for iteration_range in the XGBoost API 1.4+ +ŷ = predict(b, dvalid, ntree_limit = b.best_iteration) ``` """ function xgboost(dm::DMatrix, a...; - num_round::Integer=10, - watchlist=Dict("train"=>dm), - kw... - ) + num_round::Integer=10, + watchlist::AbstractDict = Dict("train" => dm), + early_stopping_rounds::Integer=0, + maximize=false, + kw... + ) + Xy = DMatrix(dm) b = Booster(Xy; kw...) + + # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 + if isa(watchlist, Dict) + if early_stopping_rounds > 0 && length(watchlist) > 1 + error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.") + end + end + + if isempty(watchlist) && early_stopping_rounds > 0 + error("Watchlist must be supplied if early_stopping_rounds is enabled.") + end + isempty(watchlist) || @info("XGBoost: starting training.") - update!(b, Xy, a...; num_round, watchlist) + update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize) isempty(watchlist) || @info("Training rounds complete.") b end + xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index cade6d5..844b255 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using CUDA: has_cuda, cu import Term using Random, SparseArrays using Test +using OrderedCollections include("utils.jl") @@ -130,6 +131,122 @@ end end end + +@testset "Early Stopping rounds" begin + + dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm) + dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm) + # test the early stopping rounds interface with a Dict data type in the watchlist + watchlist = Dict("eval"=>dtest, "train"=>dtrain) + + bst = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"] + ) + + # test if it ran all the way till the end (baseline) + nrounds_bst = XGBoost.getnrounds(bst) + @test nrounds_bst == 30 + + let err = nothing + try + # Check to see that xgboost will error out when watchlist supplied is a dictionary with early_stopping_rounds enabled + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + nrounds_bst = XGBoost.getnrounds(bst) + nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) + catch err + end + + @test err isa Exception + end + + # test the early stopping rounds interface with an OrderedDict data type in the watchlist + watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) + + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist_ordered, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + + + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + + # get the rmse difference for the dtest + ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration) + + filename = "agaricus.txt.test" + lines = readlines(testfilepath(filename)) + y = [parse(Float64,split(s)[1]) for s in lines] + + function calc_rmse(y_true::Vector{T}, y_pred::Vector{T}) where T <: Float64 + return sqrt(sum((y_true .- y_pred).^2)/length(y_true)) + end + + calc_metric = calc_rmse(Float64.(y), Float64.(ŷ)) + + # ensure that the results are the same (as numerically possible) with the best round + @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 + + # test the early stopping rounds interface with an OrderedDict data type in the watchlist using num_parallel_tree parameter + # this will test the XGBoost API for iteration_range is being utilised properly + watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) + + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist_ordered, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2, + num_parallel_tree = 10, + colsample_bylevel = 0.5 + ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + + # get the rmse difference for the dtest + ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration) + calc_metric = calc_rmse(Float64.(y), Float64.(ŷ)) + + # ensure that the results are the same (as numerically possible) with the best round + @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 + + # Test the interface with no watchlist provided (it'll default to training watchlist) + let err = nothing + try + bst_early_stopping = xgboost(dtrain, + num_round=30, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + catch err + end + + @test !(err isa Exception) + end +end + + @testset "Blobs training" begin (X, y) = load_classification()