From 489d9794363e4678972f268f9116c548adf924c5 Mon Sep 17 00:00:00 2001 From: Michael Mahoney Date: Wed, 7 Dec 2022 10:05:24 -0500 Subject: [PATCH] Add get_rsplit() helper (#399) * Add get_rsplit() helper * Add to pkgdown index * avoid potential awkwardness down the line of adding a method for some other class than `rset` but the main arg being named `rset` * make the snaps easier to read by adding the call that generated the error Co-authored-by: Hannah Frick --- NAMESPACE | 3 ++ R/misc.R | 59 +++++++++++++++++++++++++++++++++++ _pkgdown.yml | 1 + man/get_rsplit.Rd | 34 ++++++++++++++++++++ tests/testthat/_snaps/misc.md | 35 +++++++++++++++++++++ tests/testthat/test-misc.R | 27 ++++++++++++++++ 6 files changed, 159 insertions(+) create mode 100644 man/get_rsplit.Rd diff --git a/NAMESPACE b/NAMESPACE index e7cd056c..e9b66cd7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -14,6 +14,8 @@ S3method(complement,sliding_index_split) S3method(complement,sliding_period_split) S3method(complement,sliding_window_split) S3method(dim,rsplit) +S3method(get_rsplit,default) +S3method(get_rsplit,rset) S3method(labels,rset) S3method(labels,rsplit) S3method(labels,vfold_cv) @@ -317,6 +319,7 @@ export(ends_with) export(everything) export(form_pred) export(gather) +export(get_rsplit) export(group_bootstraps) export(group_initial_split) export(group_mc_cv) diff --git a/R/misc.R b/R/misc.R index fb4273e8..d19a96da 100644 --- a/R/misc.R +++ b/R/misc.R @@ -278,3 +278,62 @@ non_random_classes <- c( "rolling_origin", "validation_time_split" ) + +#' Retrieve individual rsplits objects from an rset +#' +#' @param x The `rset` object to retrieve an rsplit from. +#' @param index An integer indicating which rsplit to retrieve: `1` for the +#' rsplit in the first row of the rset, `2` for the second, and so on. +#' @inheritParams rlang::args_dots_empty +#' +#' @return The rsplit object in row `index` of `rset` +#' +#' @examples +#' set.seed(123) +#' (starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +#' get_rsplit(starting_splits, 1) +#' +#' @rdname get_rsplit +#' @export +get_rsplit <- function(x, index, ...) { + UseMethod("get_rsplit") +} + +#' @rdname get_rsplit +#' @export +get_rsplit.rset <- function(x, index, ...) { + rlang::check_dots_empty() + + n_rows <- nrow(x) + + acceptable_index <- length(index) == 1 && + rlang::is_integerish(index) && + index > 0 && + index <= n_rows + + if (!acceptable_index) { + msg <- ifelse( + length(index) != 1, + glue::glue("Index was of length {length(index)}."), + glue::glue("A value of {index} was provided.") + ) + + rlang::abort( + c( + glue::glue("`index` must be a length-1 integer between 1 and {n_rows}."), + x = msg + ) + ) + } + + x$splits[[index]] +} + +#' @rdname get_rsplit +#' @export +get_rsplit.default <- function(x, index, ...) { + cls <- paste0("'", class(x), "'", collapse = ", ") + rlang::abort( + paste("No `get_rsplit()` method for this class(es)", cls) + ) +} diff --git a/_pkgdown.yml b/_pkgdown.yml index b4fb6c84..5364308c 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -62,6 +62,7 @@ reference: - add_resample_id - complement - form_pred + - get_rsplit - starts_with("labels") - make_splits - make_strata diff --git a/man/get_rsplit.Rd b/man/get_rsplit.Rd new file mode 100644 index 00000000..3eb81ae3 --- /dev/null +++ b/man/get_rsplit.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{get_rsplit} +\alias{get_rsplit} +\alias{get_rsplit.rset} +\alias{get_rsplit.default} +\title{Retrieve individual rsplits objects from an rset} +\usage{ +get_rsplit(x, index, ...) + +\method{get_rsplit}{rset}(x, index, ...) + +\method{get_rsplit}{default}(x, index, ...) +} +\arguments{ +\item{x}{The \code{rset} object to retrieve an rsplit from.} + +\item{index}{An integer indicating which rsplit to retrieve: \code{1} for the +rsplit in the first row of the rset, \code{2} for the second, and so on.} + +\item{...}{These dots are for future extensions and must be empty.} +} +\value{ +The rsplit object in row \code{index} of \code{rset} +} +\description{ +Retrieve individual rsplits objects from an rset +} +\examples{ +set.seed(123) +(starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +get_rsplit(starting_splits, 1) + +} diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 2f7749e0..525a5567 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -140,3 +140,38 @@ `rset` must be an rset object +# get_rsplit() + + Code + get_rsplit(val, 3) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x A value of 3 was provided. + +--- + + Code + get_rsplit(val, c(1, 2)) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x Index was of length 2. + +--- + + Code + get_rsplit(val, 1.5) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x A value of 1.5 was provided. + +--- + + Code + get_rsplit(warpbreaks, 1) + Condition + Error in `get_rsplit()`: + ! No `get_rsplit()` method for this class(es) 'data.frame' + diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index afa69b31..9c6fa573 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -141,3 +141,30 @@ test_that("reshuffle_rset is working", { expect_snapshot_error(reshuffle_rset(rset_subclasses[["manual_rset"]]$splits[[1]])) }) + +test_that("get_rsplit()", { + + val <- withr::with_seed( + 11, + validation_split(warpbreaks) + ) + + expect_identical(val$splits[[1]], get_rsplit(val, 1)) + + expect_snapshot(error = TRUE,{ + get_rsplit(val, 3) + }) + + expect_snapshot(error = TRUE,{ + get_rsplit(val, c(1, 2)) + }) + + expect_snapshot(error = TRUE,{ + get_rsplit(val, 1.5) + }) + + expect_snapshot(error = TRUE,{ + get_rsplit(warpbreaks, 1) + }) + +})