prior_terminal_node_expo hyperparameter has incorrect default range for dbarts #251

Closed
@n8layman

The problem

The default range of the prior_terminal_node_expo (power) hyperparameter is currently [0, 3]. However, an exponent below one can cause heavy memory use because the trees grow explosively. The dbarts documentation explicitly recommends power > 1:

power: A vector of real numbers greater than one, setting the BART hyperparameter for
the tree prior’s growth probability, given by base/(1 + depth)^power.

Source: p. 28 of dbarts.pdf, under the xbart section.
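
To make the quoted formula concrete, here is a minimal base R sketch comparing the prior split probability for an exponent below one and an exponent above one. The helper name `split_prob` is hypothetical, and the values base = 0.95 and power = 2 are assumed here to be typical BART settings rather than taken from this issue.

```r
# Prior probability that a node at a given depth splits:
# base / (1 + depth)^power (formula quoted from the dbarts documentation).
# `split_prob` is a hypothetical helper; base = 0.95 and power = 2 are assumed values.
split_prob <- function(depth, base = 0.95, power = 2) {
  base / (1 + depth)^power
}

depth <- 0:5
probs <- data.frame(
  depth      = depth,
  power_half = split_prob(depth, power = 0.5), # exponent below one
  power_two  = split_prob(depth, power = 2)    # exponent above one
)
print(round(probs, 3))
```

With the smaller exponent the split probability decays much more slowly with depth, so deep trees keep a substantial prior probability of growing further.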

Reproducible example

library(tidyverse)

dials::prior_terminal_node_expo()
#> Terminal Node Prior Exponent (quantitative)
#> Range: [0, 3]

n_folds <- 10
grid <- dials::grid_latin_hypercube(size = n_folds,
                                    dials::prior_outcome_range(), # k
                                    dials::prior_terminal_node_expo(), # power
                                    dials::prior_terminal_node_coef(), # base
                                    dials::trees(range = c(25, 300))) |>
  dplyr::rename(k = prior_outcome_range, power = prior_terminal_node_expo, base = prior_terminal_node_coef, n.trees = trees)
grid

#> # A tibble: 10 × 4
#>        k  power   base n.trees
#>    <dbl>  <dbl>  <dbl>   <int>
#>  1 2.55  0.0803 0.964       61
#>  2 4.47  1.26   0.132      233
#>  3 1.40  1.78   0.863      182
#>  4 1.84  2.66   0.525      255
#>  5 0.194 1.02   0.602      199
#>  6 2.22  1.89   0.766      120
#>  7 4.98  0.387  0.0652     294
#>  8 3.93  2.88   0.212       87
#>  9 0.761 0.724  0.438       39
#> 10 3.38  2.23   0.329      152

# Plot the prior growth probability, base / (1 + depth)^power, against depth
# for each sampled hyperparameter combination.
max_depth <- 20
depth <- matrix(1:max_depth, nrow = nrow(grid), ncol = max_depth, byrow = TRUE)
# Each column holds a single depth value; grid$base and grid$power vary by row.
depth <- apply(depth, 2, function(v) grid$base / ((1 + v)^grid$power))
grid |> dplyr::bind_cols(depth, .name_repair = ~ c(names(grid), paste0("X", 1:max_depth))) |>
  dplyr::mutate(r = 1:dplyr::n()) |>
  tidyr::pivot_longer(starts_with("X"), names_prefix = "X", values_to = "growth probability", names_to = "depth", names_transform = list(depth = as.integer)) |>
  ggplot2::ggplot(aes(x = depth, y = `growth probability`, col = as.factor(r), group = r)) +
  geom_line()
#> Warning: Removed 5 row(s) containing missing values (geom_path).

[Figure: tree_growth_test. Growth probability versus depth, one line per grid row.]

Created on 2022-09-06 by the reprex package (v2.0.1)
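
Until the default changes, one possible workaround is to pass an explicit range when building the grid. This is a sketch, assuming the `range` argument of `dials::prior_terminal_node_expo()` and the `dials::range_get()` accessor behave as in dials 1.0.0. The lower bound of 1 is an illustrative choice; it is inclusive, so a value of exactly 1 is not strictly excluded.

```r
# Assumed workaround: restrict the exponent to [1, 3] instead of the default [0, 3].
# The bounds are illustrative, not a recommendation from the dials maintainers.
expo <- dials::prior_terminal_node_expo(range = c(1, 3))

# Inspect the bounds that grid functions will sample from.
bounds <- dials::range_get(expo)
print(bounds)
```

The resulting `expo` object can replace the bare `dials::prior_terminal_node_expo()` call in the grid construction above.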

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.5.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Los_Angeles
#>  date     2022-09-06
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package       * version date (UTC) lib source
#>  P assertthat      0.2.1   2019-03-21 [?] CRAN (R 4.2.0)
#>  P backports       1.4.1   2021-12-13 [?] CRAN (R 4.2.0)
#>    broom           1.0.1   2022-08-29 [1] CRAN (R 4.2.0)
#>  P cellranger      1.1.0   2016-07-27 [?] CRAN (R 4.2.0)
#>  P cli             3.3.0   2022-04-25 [?] CRAN (R 4.2.0)
#>  P colorspace      2.0-3   2022-02-21 [?] CRAN (R 4.2.0)
#>  P crayon          1.5.1   2022-03-26 [?] CRAN (R 4.2.0)
#>  P curl            4.3.2   2021-06-23 [?] CRAN (R 4.2.0)
#>  P DBI             1.1.3   2022-06-18 [?] CRAN (R 4.2.0)
#>    dbplyr          2.2.1   2022-06-27 [1] CRAN (R 4.2.0)
#>    dials           1.0.0   2022-06-14 [1] CRAN (R 4.2.0)
#>  P DiceDesign      1.9     2021-02-13 [?] CRAN (R 4.2.0)
#>  P digest          0.6.29  2021-12-01 [?] CRAN (R 4.2.0)
#>    dplyr         * 1.0.9   2022-04-28 [1] CRAN (R 4.2.0)
#>  P ellipsis        0.3.2   2021-04-29 [?] CRAN (R 4.2.0)
#>  P evaluate        0.15    2022-02-18 [?] CRAN (R 4.2.0)
#>  P fansi           1.0.3   2022-03-24 [?] CRAN (R 4.2.0)
#>  P farver          2.1.0   2021-02-28 [?] CRAN (R 4.2.0)
#>  P fastmap         1.1.0   2021-01-25 [?] CRAN (R 4.2.0)
#>  P forcats       * 0.5.1   2021-01-27 [?] CRAN (R 4.2.0)
#>  P fs              1.5.2   2021-12-08 [?] CRAN (R 4.2.0)
#>  P gargle          1.2.0   2021-07-02 [?] CRAN (R 4.2.0)
#>    generics        0.1.3   2022-07-05 [1] CRAN (R 4.2.0)
#>    ggplot2       * 3.3.6   2022-05-03 [1] CRAN (R 4.2.0)
#>  P glue            1.6.2   2022-02-24 [?] CRAN (R 4.2.0)
#>  P googledrive     2.0.0   2021-07-08 [?] CRAN (R 4.2.0)
#>  P googlesheets4   1.0.0   2021-07-21 [?] CRAN (R 4.2.0)
#>  P gtable          0.3.0   2019-03-25 [?] CRAN (R 4.2.0)
#>    hardhat         1.2.0   2022-06-30 [1] CRAN (R 4.2.0)
#>    haven           2.5.1   2022-08-22 [1] CRAN (R 4.2.0)
#>  P highr           0.9     2021-04-16 [?] CRAN (R 4.2.0)
#>  P hms             1.1.1   2021-09-26 [?] CRAN (R 4.2.0)
#>  P htmltools       0.5.2   2021-08-25 [?] CRAN (R 4.2.0)
#>    httr            1.4.3   2022-05-04 [1] CRAN (R 4.2.0)
#>  P jsonlite        1.8.0   2022-02-22 [?] CRAN (R 4.2.0)
#>  P knitr           1.38    2022-03-25 [?] RSPM (R 4.2.0)
#>  P labeling        0.4.2   2020-10-20 [?] CRAN (R 4.2.0)
#>  P lifecycle       1.0.1   2021-09-24 [?] CRAN (R 4.2.0)
#>  P lubridate       1.8.0   2021-10-07 [?] CRAN (R 4.2.0)
#>  P magrittr        2.0.3   2022-03-30 [?] CRAN (R 4.2.0)
#>  P mime            0.12    2021-09-28 [?] CRAN (R 4.2.0)
#>  P modelr          0.1.8   2020-05-19 [?] CRAN (R 4.2.0)
#>  P munsell         0.5.0   2018-06-12 [?] CRAN (R 4.2.0)
#>    pillar          1.8.0   2022-07-18 [1] CRAN (R 4.2.0)
#>  P pkgconfig       2.0.3   2019-09-22 [?] CRAN (R 4.2.0)
#>  P purrr         * 0.3.4   2020-04-17 [?] CRAN (R 4.2.0)
#>  P R.cache         0.16.0  2022-07-21 [?] CRAN (R 4.2.0)
#>    R.methodsS3     1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo            1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils         2.12.0  2022-06-28 [1] CRAN (R 4.2.0)
#>  P R6              2.5.1   2021-08-19 [?] CRAN (R 4.2.0)
#>    readr         * 2.1.2   2022-01-30 [1] CRAN (R 4.2.0)
#>  P readxl          1.4.0   2022-03-28 [?] CRAN (R 4.2.0)
#>  P reprex          2.0.1   2021-08-05 [?] CRAN (R 4.2.0)
#>    rlang           1.0.4   2022-07-12 [1] CRAN (R 4.2.0)
#>  P rmarkdown       2.11    2021-09-14 [?] RSPM (R 4.2.0)
#>  P rstudioapi      0.13    2020-11-12 [?] CRAN (R 4.2.0)
#>  P rvest           1.0.2   2021-10-16 [?] CRAN (R 4.2.0)
#>  P scales          1.2.0   2022-04-13 [?] CRAN (R 4.2.0)
#>  P sessioninfo     1.2.2   2021-12-06 [?] CRAN (R 4.2.0)
#>    stringi         1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  P stringr       * 1.4.0   2019-02-10 [?] CRAN (R 4.2.0)
#>  P styler          1.7.0   2022-03-13 [?] CRAN (R 4.2.0)
#>    tibble        * 3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  P tidyr         * 1.2.0   2022-02-01 [?] CRAN (R 4.2.0)
#>  P tidyselect      1.1.2   2022-02-21 [?] CRAN (R 4.2.0)
#>    tidyverse     * 1.3.2   2022-07-18 [1] CRAN (R 4.2.0)
#>    tzdb            0.3.0   2022-03-28 [1] CRAN (R 4.2.0)
#>  P utf8            1.2.2   2021-07-24 [?] CRAN (R 4.2.0)
#>  P vctrs           0.4.1   2022-04-13 [?] CRAN (R 4.2.0)
#>  P withr           2.5.0   2022-03-03 [?] CRAN (R 4.2.0)
#>    xfun            0.32    2022-08-10 [1] CRAN (R 4.2.0)
#>  P xml2            1.3.3   2021-11-30 [?] CRAN (R 4.2.0)
#>  P yaml            2.3.5   2022-02-21 [?] CRAN (R 4.2.0)
#> 
#> ──────────────────────────────────────────────────────────────────────────────
