Skip to content

Commit

Permalink
run notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
bkmartinjr committed Sep 17, 2024
1 parent 17f2260 commit e66b4c2
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -43,7 +43,7 @@
" CZI_Census_Homo_Sapiens_URL,\n",
" context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n",
")\n",
"obs_value_filter = \"tissue_general == 'lung' and is_primary_data == True\"\n",
"obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
"\n",
"with experiment.axis_query(\n",
" measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
Expand All @@ -57,7 +57,8 @@
" obs_column_names=[\"cell_type\"],\n",
" batch_size=128,\n",
" shuffle=True,\n",
" )\n"
" seed=12345,\n",
" )"
]
},
{
Expand All @@ -69,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -116,7 +117,7 @@
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" return optimizer\n"
" return optimizer"
]
},
{
Expand All @@ -128,11 +129,76 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params | Mode \n",
"-----------------------------------------------------\n",
"0 | linear | Linear | 726 K | train\n",
"1 | loss_fn | CrossEntropyLoss | 0 | train\n",
"-----------------------------------------------------\n",
"726 K Trainable params\n",
"0 Non-trainable params\n",
"726 K Total params\n",
"2.905 Total estimated model params size (MB)\n",
"2 Modules in train mode\n",
"0 Modules in eval mode\n",
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n",
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n",
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.87it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=20` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.86it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]\n"
]
}
],
"source": [
"dataloader = soma_ml.experiment_dataloader(experiment_dataset, num_workers=2, persistent_workers=True)\n",
"dataloader = soma_ml.experiment_dataloader(\n",
" experiment_dataset, num_workers=2, persistent_workers=True\n",
")\n",
"\n",
"# The size of the input dimension is the number of genes\n",
"input_dim = experiment_dataset.shape[1]\n",
Expand All @@ -147,16 +213,16 @@
"\n",
"# Define the PyTorch Lightning Trainer\n",
"trainer = pl.Trainer(\n",
" max_epochs=10,\n",
" max_epochs=20,\n",
" # accelerator=args.accelerator,\n",
" # strategy=\"ddp\",\n",
")\n",
"\n",
"# set precision\n",
"torch.set_float32_matmul_precision('medium')\n",
"torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"# Train the model\n",
"trainer.fit(model, train_dataloaders=dataloader)\n"
"trainer.fit(model, train_dataloaders=dataloader)"
]
}
],
Expand Down
Loading

0 comments on commit e66b4c2

Please sign in to comment.