A comprehensive toolkit and benchmark for tabular data learning, featuring over 20 deep methods, more than 10 classical methods, and 300 diverse tabular datasets.

qile2000/LAMDA-TALENT

TALENT: A Tabular Analytics and Learning Toolbox

[Paper] [Code]


🎉 Introduction

Welcome to TALENT, a benchmark with a comprehensive machine learning toolbox designed to enhance model performance on tabular data. TALENT integrates advanced deep learning models, classical algorithms, and efficient hyperparameter tuning, offering robust preprocessing capabilities to optimize learning from tabular datasets. The toolbox is user-friendly and adaptable, catering to both novice and expert data scientists.

TALENT offers the following advantages:

  • Diverse Methods: Includes various classical methods, tree-based methods, and the latest popular deep learning methods.
  • Extensive Dataset Collection: Equipped with 300 datasets, covering a wide range of task types, size distributions, and dataset domains.
  • Customizability: Easily allows the addition of datasets and methods.
  • Versatile Support: Supports diverse normalization, encoding, and metrics.

📚 Citing TALENT

If you use any content of this repo for your work, please cite the following bib entry:

TODO

🌟 Methods

TALENT integrates an extensive array of 20+ deep learning architectures for tabular data, including but not limited to:

  • MLP: A multi-layer neural network, which is implemented according to RTDL.
  • ResNet: A DNN that uses skip connections across many layers, which is implemented according to RTDL.
  • SNN: An MLP-like architecture utilizing the SELU activation, which facilitates the training of deeper neural networks.
  • DANets: A neural network designed to enhance tabular data processing by grouping correlated features and reducing computational complexity.
  • TabCaps: A capsule network that encapsulates all feature values of a record into vectorial features.
  • DCNv2: Consists of an MLP-like module combined with a feature crossing module, which includes both linear layers and multiplications.
  • NODE: A tree-mimic method that generalizes oblivious decision trees, combining gradient-based optimization with hierarchical representation learning.
  • GrowNet: A gradient boosting framework that uses shallow neural networks as weak learners.
  • TabNet: A tree-mimic method using sequential attention for feature selection, offering interpretability and self-supervised learning capabilities.
  • TabR: A deep learning model that integrates a KNN component to enhance tabular data predictions through an efficient attention-like mechanism.
  • ModernNCA: A deep tabular model inspired by traditional Neighbor Component Analysis, which makes predictions based on the relationships with neighbors in a learned embedding space.
  • DNNR: Enhances KNN by using local gradients and Taylor approximations for more accurate and interpretable predictions.
  • AutoInt: A token-based method that uses a multi-head self-attentive neural network to automatically learn high-order feature interactions.
  • Saint: A token-based method that leverages row and column attention mechanisms for tabular data.
  • TabTransformer: A token-based method that enhances tabular data modeling by transforming categorical features into contextual embeddings.
  • FT-Transformer: A token-based method which transforms features to embeddings and applies a series of attention-based transformations to the embeddings.
  • TANGOS: A regularization-based method for tabular data that uses gradient attributions to encourage neuron specialization and orthogonalization.
  • SwitchTab: A self-supervised method tailored for tabular data that improves representation learning through an asymmetric encoder-decoder framework.
  • PTaRL: A regularization-based framework that enhances prediction by constructing and projecting into a prototype-based space.
  • TabPFN: A general model which involves the use of pre-trained deep neural networks that can be directly applied to any tabular task.
  • HyperFast: A meta-trained hypernetwork that generates task-specific neural networks for instant classification of tabular data.
  • TabPTM: A general method for tabular data that standardizes heterogeneous datasets using meta-representations, allowing a pre-trained model to generalize to unseen datasets without additional training.

☄️ How to Use TALENT

🕹️ Clone

Clone this GitHub repository:

git clone https://github.com/qile2000/LAMDA-TALENT
cd LAMDA-TALENT/LAMDA-TALENT

🔑 Run experiment

  1. Edit configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json for global settings and hyperparameters.

  2. Run:

    python train_model_deep.py --model_type MODEL_NAME

    for deep methods, or:

    python train_model_classical.py --model_type MODEL_NAME

    for classical methods.
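
When several methods need to be run in a row, the two commands above can be wrapped in a small driver. The following is only a sketch: it assumes it is launched from the directory that contains the two training scripts, it keeps MODEL_NAME as the placeholder used above, and the helper names build_command and run_all are hypothetical, not part of TALENT.

```python
import subprocess
import sys

# Script names and the --model_type flag are taken from the steps above.
SCRIPTS = {"deep": "train_model_deep.py", "classical": "train_model_classical.py"}


def build_command(model_type, family="deep"):
    """Return the argv list for one training run (hypothetical helper)."""
    if family not in SCRIPTS:
        raise ValueError(f"family must be one of {sorted(SCRIPTS)}, got {family!r}")
    return [sys.executable, SCRIPTS[family], "--model_type", model_type]


def run_all(jobs, dry_run=True):
    """Run (model_type, family) pairs in order; with dry_run, only print them."""
    for model_type, family in jobs:
        cmd = build_command(model_type, family)
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the loop at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # "MODEL_NAME" is the placeholder from the steps above; substitute a real method name.
    run_all([("MODEL_NAME", "deep"), ("MODEL_NAME", "classical")])
```

With dry_run left at True the script only prints the commands, which is a cheap way to check the list before starting long training runs.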

🛠️ How to Add New Methods

For methods like MLP, where only the model itself has to be defined, you only need to:

  • Add the model class in model/models.
  • Inherit from the base class in model/methods/base.py and override the construct_model() method in the new class.
  • Add the method name in the get_method function in model/utils.py.
  • Add the parameter settings for the new method in configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json.

For other methods that require changing the training process, partially override functions based on model/methods/base.py. For details, refer to the implementation of other methods in model/methods/.
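
The extension pattern described above can be sketched in plain Python. Everything in this block is a stand-in: the class names BaseMethod, MyModel, and MyMethod, the config layout, and the body of get_method are assumptions made for illustration, since the actual contents of model/methods/base.py and model/utils.py are not shown here. Only the overall shape (a base class that owns training, a subclass that overrides construct_model(), and a name lookup in get_method) follows the steps listed.

```python
# Stand-in for the base class in model/methods/base.py. The real class name,
# constructor arguments, and training loop are assumptions.
class BaseMethod:
    def __init__(self, config):
        self.config = config
        self.model = None

    def construct_model(self):
        raise NotImplementedError

    def fit(self):
        # The base class owns the training flow; subclasses only supply the model.
        self.construct_model()
        return self.model


# Hypothetical model class, of the kind that would live in model/models.
class MyModel:
    def __init__(self, d_in, d_out, hidden):
        self.shape = (d_in, hidden, d_out)


class MyMethod(BaseMethod):
    def construct_model(self):
        # In TALENT the hyperparameters would come from the method's JSON config;
        # the {"model": {...}} layout used here is an assumption.
        self.model = MyModel(**self.config["model"])


def get_method(name):
    # Stand-in for get_method in model/utils.py: map a method name to its class.
    registry = {"mymethod": MyMethod}
    if name not in registry:
        raise NotImplementedError(f"unknown method: {name}")
    return registry[name]


method = get_method("mymethod")({"model": {"d_in": 10, "d_out": 1, "hidden": 64}})
model = method.fit()
print(model.shape)  # (10, 64, 1)
```

A method that changes the training process would override more than construct_model(), for example fit() in this sketch.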

📦 Dependencies

  1. torch
  2. scikit-learn
  3. pandas
  4. tqdm
  5. numpy
  6. scipy

🗂️ Benchmark Datasets

Datasets are available at Google Drive.

📂 How to Place Datasets

Place datasets under the project root, in the directory named by args.dataset_path. For instance, if the project is LAMDA-TALENT, each dataset should be placed in LAMDA-TALENT/args.dataset_path/args.dataset.

Each dataset folder args.dataset consists of:

  • Numeric features: N_train/val/test.npy (can be omitted if there are no numeric features)

  • Categorical features: C_train/val/test.npy (can be omitted if there are no categorical features)

  • Labels: y_train/val/test.npy

  • info.json, which must include the following three fields (task_type can be "regression", "multiclass", or "binclass"):

    {
      "task_type": "regression", 
      "n_num_features": 10,
      "n_cat_features": 10
    }
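
Before training, it can help to check that a dataset folder matches the layout above. The sketch below uses only the standard library and makes two assumptions: that "N_train/val/test.npy" stands for the three files N_train.npy, N_val.npy, and N_test.npy (likewise for C_ and y_), and that the N_ or C_ files are expected only when the corresponding feature count in info.json is greater than zero. The function name check_dataset_dir is hypothetical, and the check looks at file names only, not at array contents.

```python
import json
from pathlib import Path

TASK_TYPES = {"regression", "multiclass", "binclass"}
REQUIRED_KEYS = ("task_type", "n_num_features", "n_cat_features")
SPLITS = ("train", "val", "test")


def check_dataset_dir(path):
    """Return a list of problems found in one dataset folder (empty if none)."""
    root = Path(path)
    info_path = root / "info.json"
    if not info_path.is_file():
        return ["missing info.json"]
    info = json.loads(info_path.read_text())

    problems = [f"info.json lacks '{key}'" for key in REQUIRED_KEYS if key not in info]
    if problems:
        return problems
    if info["task_type"] not in TASK_TYPES:
        problems.append(f"unknown task_type: {info['task_type']!r}")

    # Labels are always expected; N_/C_ files only when that feature kind exists.
    prefixes = ["y"]
    if info["n_num_features"] > 0:
        prefixes.append("N")
    if info["n_cat_features"] > 0:
        prefixes.append("C")
    for prefix in prefixes:
        for split in SPLITS:
            name = f"{prefix}_{split}.npy"
            if not (root / name).is_file():
                problems.append(f"missing {name}")
    return problems
```

Calling check_dataset_dir on a folder returns an empty list when nothing is missing, so it can be used as a simple guard before launching an experiment.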

📝 Experimental Results

The figures below give a comprehensive and fair evaluation of classical and deep tabular methods, all run through our toolbox. Three tabular prediction tasks are considered (binary classification, multi-class classification, and regression), and each subfigure represents a different task type.

We use Accuracy and RMSE as the metrics for classification tasks and regression tasks, respectively. Because raw metric values are not comparable across datasets, we compare all methods by their average performance rank, where a lower rank indicates better performance, following Sheskin (2003). Efficiency is measured as the average training time in seconds, with lower values denoting better time efficiency. Model size is indicated by the radius of the circles, offering a quick view of the trade-off between model complexity and performance.
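
To make the average performance rank concrete, here is a minimal sketch of how it can be computed. It is not the benchmark's evaluation code: the function names are hypothetical, the scores are made-up numbers for three placeholder methods, and giving tied methods the average of their positions is an assumption (a common convention for rank-based comparisons).

```python
from statistics import mean


def rank_scores(scores, higher_is_better):
    """Rank one dataset's scores; rank 1 is best, ties share the average rank."""
    ordered = sorted(scores, key=scores.get, reverse=higher_is_better)
    ranks = {}
    i = 0
    while i < len(ordered):
        j = i
        # Extend j over the run of methods tied with ordered[i].
        while j + 1 < len(ordered) and scores[ordered[j + 1]] == scores[ordered[i]]:
            j += 1
        shared = (i + 1 + j + 1) / 2  # mean of the 1-based positions i+1 .. j+1
        for name in ordered[i:j + 1]:
            ranks[name] = shared
        i = j + 1
    return ranks


def average_rank(results):
    """results: list of (scores_by_method, higher_is_better), one entry per dataset."""
    per_dataset = [rank_scores(scores, hib) for scores, hib in results]
    methods = per_dataset[0].keys()
    return {m: mean(r[m] for r in per_dataset) for m in methods}


# Made-up numbers for three placeholder methods, for illustration only.
results = [
    ({"A": 0.91, "B": 0.88, "C": 0.88}, True),   # accuracy: higher is better
    ({"A": 0.75, "B": 0.80, "C": 0.70}, True),   # accuracy: higher is better
    ({"A": 1.20, "B": 1.50, "C": 1.10}, False),  # RMSE: lower is better
]
print(average_rank(results))
```

Ranking within each dataset first is what lets accuracy and RMSE results be averaged together, since the ranks, unlike the raw metrics, share one scale.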

  • Binary classification

  • Multiclass classification

  • Regression

  • All tasks

From the comparison, we observe that CatBoost achieves the best average rank in most classification and regression tasks. Among all deep tabular methods, ModernNCA performs the best in most cases while maintaining an acceptable training cost. These results highlight the effectiveness of CatBoost and ModernNCA in handling various tabular prediction tasks, making them suitable choices for practitioners seeking high performance and efficiency.

These visualizations serve as an effective tool for quickly and fairly assessing the strengths and weaknesses of various tabular methods across different task types, enabling researchers and practitioners to make informed decisions when selecting suitable modeling techniques for their specific needs.

👨‍🏫 Acknowledgments

We thank the following repos for providing helpful components/functions in our work:

🤗 Contact

If you have any questions or would like to propose new features, please open an issue or contact the authors: Siyang Liu ([email protected]), Haorun Cai ([email protected]), Qile Zhou ([email protected]), and Han-Jia Ye ([email protected]). Enjoy the code.

Thanks to LAMDA-PILOT and LAMDA-ZhiJian for the template.
