Issues with device = "mps" on Sonoma OS #144

Open
cgoo4 opened this issue Jan 12, 2024 · 3 comments
Labels: bug (Something isn't working)


@cgoo4

cgoo4 commented Jan 12, 2024

Running the example below in a fresh R session, tabnet_pretrain() works with device = "mps", but tabnet_fit() hangs (with no message) and I have to terminate R to recover. Session info is attached.

library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id")), new_role = "predictor")

set.seed(1)

tab_pre <- tab_rec |> 
  tabnet_pretrain(train, device = "mps", checkpoint_epochs = 2)

tab_fit <- tab_rec |>
  tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, device = "cpu") # hangs with "mps"

test |> bind_cols(predict(tab_fit, test))
#> # A tibble: 2,465 × 24
#>    funded_amnt term    int_rate sub_grade addr_state verification_status
#>          <int> <fct>      <dbl> <fct>     <fct>      <fct>              
#>  1       10000 term_36    11.5  B5        TX         Source_Verified    
#>  2        7000 term_36    13.0  C2        CA         Source_Verified    
#>  3       35000 term_36    11.5  B5        TN         Source_Verified    
#>  4       15000 term_36    10.8  B4        TX         Not_Verified       
#>  5       27200 term_60    10.8  B4        NC         Not_Verified       
#>  6       12000 term_36    14.5  C4        OR         Source_Verified    
#>  7       15025 term_36    13.7  C3        MA         Source_Verified    
#>  8       20000 term_36     5.32 A1        WI         Not_Verified       
#>  9       20000 term_36    12.0  C1        VA         Verified           
#> 10       10000 term_36    10.8  B4        NC         Verified           
#> # ℹ 2,455 more rows
#> # ℹ 18 more variables: annual_inc <dbl>, emp_length <fct>, delinq_2yrs <int>,
#> #   inq_last_6mths <int>, revol_util <dbl>, acc_now_delinq <int>,
#> #   open_il_6m <int>, open_il_12m <int>, open_il_24m <int>, total_bal_il <int>,
#> #   all_util <int>, inq_fi <int>, inq_last_12m <int>, delinq_amnt <int>,
#> #   num_il_tl <int>, total_il_high_credit_limit <int>, Class <fct>,
#> #   .pred_class <fct>

Created on 2024-01-12 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.2 (2023-10-31)
#>  os       macOS Sonoma 14.2.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-01-12
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [2] CRAN (R 4.3.0)
#>  bit            4.0.5      2022-11-15 [2] CRAN (R 4.3.0)
#>  bit64          4.0.5      2020-08-30 [2] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [2] CRAN (R 4.3.0)
#>  callr          3.7.3      2022-11-02 [2] CRAN (R 4.3.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.2)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.1)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  coro           1.0.3      2022-07-19 [2] CRAN (R 4.3.0)
#>  data.table     1.14.10    2023-12-08 [1] CRAN (R 4.3.1)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.0)
#>  evaluate       0.23       2023-11-01 [2] CRAN (R 4.3.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
#>  fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.3.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [2] CRAN (R 4.3.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.1)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.4.4      2023-10-12 [1] CRAN (R 4.3.1)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.0)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.0)
#>  htmltools      0.5.7      2023-11-03 [2] CRAN (R 4.3.1)
#>  infer        * 1.0.5      2023-09-06 [2] CRAN (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [2] CRAN (R 4.3.1)
#>  knitr          1.45       2023-10-30 [2] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.1)
#>  lava           1.7.3      2023-11-04 [1] CRAN (R 4.3.1)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60     2023-05-04 [2] CRAN (R 4.3.2)
#>  Matrix         1.6-4      2023-11-30 [2] CRAN (R 4.3.1)
#>  modeldata    * 1.2.0      2023-08-09 [2] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.2)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.1.1      2023-08-17 [1] CRAN (R 4.3.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  processx       3.8.3      2023-12-10 [2] CRAN (R 4.3.1)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#>  ps             1.7.5      2023-04-18 [2] CRAN (R 4.3.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
#>  R.cache        0.16.0     2022-07-21 [2] CRAN (R 4.3.0)
#>  R.methodsS3    1.8.2      2022-06-13 [2] CRAN (R 4.3.0)
#>  R.oo           1.25.0     2022-06-12 [2] CRAN (R 4.3.0)
#>  R.utils        2.12.3     2023-11-18 [2] CRAN (R 4.3.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.1)
#>  recipes      * 1.0.9      2023-12-13 [1] CRAN (R 4.3.1)
#>  reprex         2.0.2      2022-08-17 [2] CRAN (R 4.3.0)
#>  rlang          1.1.3      2024-01-10 [1] CRAN (R 4.3.1)
#>  rmarkdown      2.25       2023-09-18 [2] CRAN (R 4.3.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.1)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.3.0)
#>  rstudioapi     0.15.0     2023-07-07 [2] CRAN (R 4.3.0)
#>  safetensors    0.1.2      2023-09-12 [2] CRAN (R 4.3.0)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.3.0)
#>  styler         1.10.2     2023-08-29 [2] CRAN (R 4.3.0)
#>  survival       3.5-7      2023-08-14 [2] CRAN (R 4.3.2)
#>  tabnet       * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.1.1      2023-08-24 [2] CRAN (R 4.3.0)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.0)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
#>  torch          0.12.0     2024-01-05 [1] Github (mlverse/torch@23071c1)
#>  tune         * 1.1.2      2023-08-23 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
#>  withr          2.5.2      2023-10-30 [1] CRAN (R 4.3.1)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.3.0)
#>  workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.3.0)
#>  xfun           0.41       2023-11-01 [2] CRAN (R 4.3.1)
#>  yaml           2.3.8      2023-12-11 [2] CRAN (R 4.3.1)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.3.0)
#>  zeallot        0.1.0      2018-01-28 [2] CRAN (R 4.3.0)
#> 
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
@cregouby
Collaborator

Hello @cgoo4

This looks like a silent out-of-memory (OOM) failure on the MPS device.
If you run it from the terminal or as a test case, you may eventually see:

! MPS backend out of memory (MPS allocated: 0 bytes, other allocations: 0 bytes, max allowed: 1.70 GB). Tried to allocate 0 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
Exception raised from alloc_buffer_block at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSAllocator.mm:235 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >) + 81 (0x105c22ca1 in libc10.dylib)

I would recommend drastically lowering virtual_batch_size and, if needed, batch_size in your tabnet_config().
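A minimal sketch of that suggestion, reusing the lending_club split from the reprex above. The batch sizes (1024 and 128) are illustrative assumptions, not values proposed in this thread. The device falls back to "cpu" when torch's backends_mps_is_available() reports no MPS support, so the sketch also runs on machines without Apple silicon.

```r
library(tabnet)
library(torch)
library(rsample)
library(modeldata)

# Same data split as in the reprex above
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)

# Use MPS only when torch reports it is available; otherwise fall back to CPU
device <- if (backends_mps_is_available()) "mps" else "cpu"

# Illustrative (assumed) values: deliberately small batches to reduce the
# memory requested per training step on the device
small_cfg <- tabnet_config(
  batch_size = 1024,
  virtual_batch_size = 128,
  device = device
)

set.seed(1)
tab_fit <- tabnet_fit(Class ~ ., data = train, config = small_cfg)
```

Hyperparameters can also be passed directly through tabnet_fit()'s `...`, as the reprex does with device and checkpoint_epochs. Note that in this issue the hang persisted even with virtual_batch_size reduced to 16, so smaller batches are a diagnostic step rather than a guaranteed fix.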

Let me know if it helps

@cgoo4
Author

cgoo4 commented Jan 13, 2024

Hi @cregouby, I've tried progressively reducing virtual_batch_size, down to 16, but tabnet_fit() still hangs. tabnet_pretrain() runs with the default settings. Both produce the out-of-memory message.

I'm using an M2 Max with 64 GB of memory.

[screenshot]

cregouby added the bug (Something isn't working) label on Jan 14, 2024
@cregouby
Collaborator

cregouby commented Jan 14, 2024

Hello @cgoo4
The problem is not related to using tabnet_model =: any second training run on device = "mps" hits the issue, whatever the model.

There are currently many open issues about device = "mps" on macOS Sonoma (TensorFlow, PyTorch, ...), so I guess there is nothing we can do until Apple's developers fix the MPS code.

cregouby changed the title from Issues with device = "mps" to Issues with device = "mps" on Sonoma OS on Jan 14, 2024