Better data preparation #1

Closed
cregouby opened this issue Mar 7, 2022 · 2 comments

Comments

@cregouby

cregouby commented Mar 7, 2022

@JunaidMB,
May I suggest the following improvements?

  • tabnet handles categorical predictors well, so in your case I would turn the "_flag" variables into logicals and the character predictors into factors.
  • tabnet now supports missing values for almost all tasks, so I would not remove the samples with NAs.
  • There is no need to dummify categorical values, as tabnet has a more powerful embedding for that.
  • I recommend increasing the batch_size here (the new default is 1024^2) to improve convergence and lower training time.

strand_pat <- NHSRdatasets::stranded_data %>%
  setNames(c("stranded_class", "age", "care_home_ref_flag", "medically_safe_flag",
             "hcop_flag", "needs_mental_health_support_flag",
             "previous_care_in_last_12_month", "admit_date", "frail_descrip")) %>%
  # Character predictors become factors; tabnet embeds them, so no dummy encoding
  mutate(across(where(is.character), as.factor),
         admit_date = as.Date(admit_date, format = "%d/%m/%Y"),
         # "_flag" columns become logicals
         across(ends_with("flag"), as.logical))
...
## Define Recipe to be applied to the dataset
stranded_rec <- 
  recipe(stranded_class ~ ., data = train_data) %>% 
  # Make a day of week and month feature from admit date and remove raw admit date
  step_date(admit_date, features = c("dow", "month")) %>% 
  step_rm(admit_date) %>% 
  # Upsample minority (positive) class
  themis::step_upsample(stranded_class, over_ratio = as.numeric(upsample_ratio)) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_numeric_predictors())
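
To make the suggestions above concrete, here is a minimal sketch on a made-up toy data frame (the values are illustrative, not taken from `NHSRdatasets::stranded_data`). It applies the same type conversions, keeps the rows with NAs, and creates no dummy columns. The `fit_with_larger_batch()` helper is hypothetical, and it assumes that `tabnet_fit()` forwards `epochs` and `batch_size` to `tabnet_config()`; the default of 4096 is an arbitrary illustration, not a recommended value.

```r
library(dplyr)

# Toy stand-in for the real data: values are made up for illustration only
toy <- data.frame(
  stranded_class = c("Stranded", "Not Stranded", "Not Stranded", "Stranded"),
  age = c(81, 45, NA, 67),
  care_home_ref_flag = c(1L, 0L, 0L, NA),
  frail_descrip = c("Mild", NA, "Severe", "Mild"),
  stringsAsFactors = FALSE
)

# Same conversions as suggested in the thread:
# character columns to factors, "_flag" columns to logicals
prepared <- toy %>%
  mutate(across(where(is.character), as.factor),
         across(ends_with("flag"), as.logical))

# Rows with NAs are kept and no dummy columns are created
str(prepared)

# Hypothetical helper (not part of tabnet): fit from a recipe with a larger
# batch size. Assumes tabnet_fit() passes epochs and batch_size on to
# tabnet_config(); the defaults here are placeholders to tune.
fit_with_larger_batch <- function(rec, data, batch_size = 4096, epochs = 10) {
  tabnet::tabnet_fit(rec, data, epochs = epochs, batch_size = batch_size)
}
```

With the objects from this thread, the call would look like `fit_with_larger_batch(stranded_rec, train_data)`. It needs the tabnet package and its torch backend installed.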
@JunaidMB
Owner

Thanks a lot @cregouby, I've implemented those improvements! Is there a place or documentation where we can see best practices for implementing Tabnet? Your presentation here was excellent at showing how to get set up with Tabnet and its basic commands, but is there somewhere we can find information like the tips you shared above?

I think many people will share my initial temptation to apply the exact same preprocessing for Tabnet that we might use for Random Forest or XGBoost, for example. If you could share anything helpful, I'll include it in the references section of the README.

@cregouby
Author

Got it! I'll add a dedicated vignette to my to-do list. But first, please measure whether there is an improvement in your metric (or not), because no one knows without looking at the data...
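
One way to run that check is to score both preprocessing variants on the same held-out set and compare. The sketch below assumes ROC AUC is the metric of interest and uses `yardstick::roc_auc_vec()`. The `compare_auc()` helper is hypothetical, and the probabilities are made-up toy values that stand in for real model predictions.

```r
library(yardstick)

# Hypothetical helper: ROC AUC of two sets of predicted probabilities
# (for the first factor level) against the same held-out truth
compare_auc <- function(truth, prob_before, prob_after) {
  c(
    before = roc_auc_vec(truth, prob_before),
    after  = roc_auc_vec(truth, prob_after)
  )
}

# Made-up toy values, for illustration only
truth <- factor(c("yes", "no", "yes", "no", "yes", "no"), levels = c("yes", "no"))
prob_before <- c(0.6, 0.5, 0.4, 0.3, 0.7, 0.6)  # e.g. old preprocessing
prob_after  <- c(0.8, 0.3, 0.6, 0.2, 0.9, 0.4)  # e.g. new preprocessing

res <- compare_auc(truth, prob_before, prob_after)
print(res)
```

In practice, `truth` and the two probability vectors would come from the same test split, so that any difference reflects the preprocessing rather than the sample.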
