issue with noether p-value #62

Open · wants to merge 2 commits into main
10 changes: 7 additions & 3 deletions README.md
@@ -1,3 +1,9 @@
<p align="center">
<!-- <img src="https://github.com/Novartis/torchsurv/blob/main/docs/source/logo_firecamp.png" width="300"> -->
<img src="./docs/source/logo_firecamp.png" width="300">

</p>

# Deep survival analysis made easy

[![Python](https://img.shields.io/pypi/pyversions/torchsurv?label=Python)](https://pypi.org/project/torchsurv/)
@@ -51,13 +57,11 @@ cindex.p_value(method="noether", alternative="two_sided")
cindex.compare(cindexB)
```
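
For context, a minimal end-to-end sketch of how these calls fit together. The import path and the `ConcordanceIndex` call signature follow the momentum notebook further down in this diff; the toy tensors, variable names, and the second metric `cindexB` are illustrative assumptions, not part of this change.

```python
import torch
from torchsurv.metrics.cindex import ConcordanceIndex  # import path assumed from src/torchsurv/metrics/cindex.py

# Toy inputs: risk scores, event indicators, and event times for 16 subjects
estimateA = torch.randn(16)               # predictions from model A (higher = riskier)
estimateB = torch.randn(16)               # predictions from model B
event = torch.ones(16, dtype=torch.bool)  # all events observed (no censoring)
time = torch.rand(16) * 10.0              # observed times

cindex = ConcordanceIndex()
cindex(estimateA, event, time)            # evaluate the C-index of model A

cindexB = ConcordanceIndex()
cindexB(estimateB, event, time)           # evaluate the C-index of model B

# Noether test of H0: C = 0.5, then compare the two models
print(cindex.p_value(method="noether", alternative="two_sided"))
print(cindex.compare(cindexB))
```

`compare(cindexB)` is expected to return a p-value for the difference between the two concordance indices; see the library documentation for the exact null hypothesis.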


## Installation and dependencies


First, install the package using either [PyPI](https://pypi.org/project/torchsurv/) or [Conda](https://anaconda.org/conda-forge/torchsurv)

- Using conda (`recommended`)
- Using conda (**recommended**)
```bash
conda install conda-forge::torchsurv
```
121 changes: 68 additions & 53 deletions docs/notebooks/introduction.ipynb

Large diffs are not rendered by default.

178 changes: 44 additions & 134 deletions docs/notebooks/momentum.ipynb
@@ -11,7 +11,7 @@
"\n",
"### Dependencies\n",
"\n",
"To run this notebooks, dependencies must be installed. the recommended method is to use our development conda environment (`preferred`). Instruction can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda) to install all optional dependencies. The other method is to install only required packages using the command line below:\n"
"To run this notebooks, dependencies must be installed. the recommended method is to use our development conda environment (**preferred**). Instruction can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda) to install all optional dependencies. The other method is to install only required packages using the command line below:\n"
]
},
{
@@ -104,11 +104,34 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "ebaf967b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA-enabled GPU/TPU is available.\n"
]
}
],
"source": [
"# Detect available accelerator; Downgrade batch size if only CPU available\n",
"if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):\n",
" print(\"CUDA-enabled GPU/TPU is available.\")\n",
" BATCH_SIZE = 500 # batch size for training\n",
"else:\n",
" print(\"No CUDA-enabled GPU found, using CPU.\")\n",
" BATCH_SIZE = 50 # batch size for training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "794004c5-588c-4590-ae96-c6d9e52109ff",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 500 # batch size for training\n",
"EPOCHS = 2 # number of epochs to train\n",
"FAST_DEV_RUN = None # Quick prototype, set to None for full training"
]
@@ -135,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "4abbc6b0",
"metadata": {},
"outputs": [],
@@ -153,7 +176,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "ebf5caff",
"metadata": {},
"outputs": [
@@ -201,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "c216fa33-de09-4be2-82cc-83cb73db3a42",
"metadata": {},
"outputs": [],
@@ -217,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "8056a675-fbce-4f4b-86c0-ab7dd924e4b1",
"metadata": {},
"outputs": [
@@ -255,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "1e7a2c7e-a1ef-42fa-ba74-1d33a1dcf2f3",
"metadata": {},
"outputs": [],
@@ -266,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "3f577acf-a821-41a4-8544-318617755d1e",
"metadata": {},
"outputs": [
@@ -296,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "430079cc-4fad-4da2-8ea5-aa904c41ec0e",
"metadata": {},
"outputs": [
@@ -319,21 +342,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [01:02<00:00, 0.18it/s, loss_step=218.0, val_loss_step=282.0, cindex_step=0.652, val_loss_epoch=287.0, cindex_epoch=0.665, loss_epoch=228.0]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [01:02<00:00, 0.18it/s, loss_step=218.0, val_loss_step=282.0, cindex_step=0.652, val_loss_epoch=287.0, cindex_epoch=0.665, loss_epoch=228.0]\n"
"Epoch 0: 0%| | 0/11 [00:00<?, ?it/s] "
]
}
],
@@ -344,34 +353,10 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "7854deb3-52f8-4a92-b38f-ff304bf82a34",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing DataLoader 0: 100%|██████████| 20/20 [01:22<00:00, 0.24it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" cindex_epoch 0.6686268448829651\n",
" val_loss_epoch -458.2862548828125\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
"[{'val_loss_epoch': -458.2862548828125, 'cindex_epoch': 0.6686268448829651}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Test the model\n",
"trainer.test(model_regular, datamodule)"
@@ -391,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "bab50a3f-5670-4264-b2c0-4eccb5f48624",
"metadata": {},
"outputs": [],
@@ -408,51 +393,10 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "00473ec0-9f44-47f2-824d-02dcc92dba7d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"\n",
" | Name | Type | Params\n",
"-----------------------------------\n",
"0 | model | Momentum | 22.3 M\n",
"-----------------------------------\n",
"11.2 M Trainable params\n",
"11.2 M Non-trainable params\n",
"22.3 M Total params\n",
"89.366 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 110/110 [01:18<00:00, 1.40it/s, loss_step=57.10, val_loss_step=63.80, cindex_step=0.848, val_loss_epoch=59.80, cindex_epoch=0.841, loss_epoch=58.10]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 110/110 [01:18<00:00, 1.40it/s, loss_step=57.10, val_loss_step=63.80, cindex_step=0.848, val_loss_epoch=59.80, cindex_epoch=0.841, loss_epoch=58.10]\n"
]
}
],
"outputs": [],
"source": [
"# Define trainer\n",
"trainer = L.Trainer(\n",
@@ -470,34 +414,10 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "6441c1ea-b87f-4ff7-92dd-a8d7abf8daa5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing DataLoader 0: 100%|██████████| 200/200 [01:38<00:00, 2.03it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" cindex_epoch 0.858147144317627\n",
" val_loss_epoch 72.23859405517578\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
"[{'val_loss_epoch': 72.23859405517578, 'cindex_epoch': 0.858147144317627}]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Validate the model\n",
"trainer.test(model_momentum, datamodule_momentum)"
@@ -513,7 +433,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "855bda61",
"metadata": {},
"outputs": [],
@@ -527,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "38b1f7d1",
"metadata": {},
"outputs": [],
@@ -553,20 +473,10 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "112e2e5d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cindex (regular) = 0.6948477029800415\n",
"Cindex (momentum) = 0.8578558564186096\n",
"Compare (p-value) = 2.1650459203215178e-11\n"
]
}
],
"outputs": [],
"source": [
"print(f\"Cindex (regular) = {cindex1(log_hz1, torch.ones_like(y).bool(), y.float())}\")\n",
"print(f\"Cindex (momentum) = {cindex2(log_hz2, torch.ones_like(y).bool(), y.float())}\")\n",
2 changes: 1 addition & 1 deletion src/torchsurv/metrics/cindex.py
@@ -585,7 +585,7 @@ def _p_value_noether(self, alternative, null_value: float = 0.5) -> torch.tensor
cindex_se = self._concordance_index_se()

# get p-value
if cindex_se > 0:
if cindex_se >= 0:
p = torch.distributions.normal.Normal(0, 1).cdf(
(self.cindex - null_value) / cindex_se
)
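
For reference, a standalone sketch of the Noether test that `_p_value_noether` implements: the z-statistic is z = (cindex - null_value) / cindex_se with null_value = 0.5 by default, and the p-value comes from the standard normal CDF. The handling of `alternative` below is a plausible reconstruction of the elided branches, not the library's exact code.

```python
import torch
from torch.distributions.normal import Normal


def noether_p_value(
    cindex: torch.Tensor,
    cindex_se: torch.Tensor,
    alternative: str = "two_sided",
    null_value: float = 0.5,
) -> torch.Tensor:
    """Large-sample (Noether) test of H0: concordance index == null_value."""
    # Standardized distance of the estimate from the null value
    z = (cindex - null_value) / cindex_se
    cdf = Normal(0.0, 1.0).cdf(z)  # P(Z <= z) under the null
    if alternative == "greater":   # H1: C > null_value
        return 1.0 - cdf
    if alternative == "less":      # H1: C < null_value
        return cdf
    return 2.0 * torch.minimum(cdf, 1.0 - cdf)  # H1: C != null_value


# When cindex_se == 0, z is +/-inf (or NaN if cindex == null_value), so the
# p-value degenerates to 0 or 1; the guard on cindex_se shown above decides
# how that boundary case is routed.
```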