diff --git a/data/helios/v2_pastis/README.md b/data/helios/v2_pastis/README.md
index a209ac97..8b300a1d 100644
--- a/data/helios/v2_pastis/README.md
+++ b/data/helios/v2_pastis/README.md
@@ -35,4 +35,8 @@ python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-
 
 # Helios with max pooling.
 python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/joer/v0.1_base_latent_mim_space_time/step165000 --patch_size 8 --encoder_embedding_size 768 --image_name favyen/rslphelios2 --config_paths+=data/helios/v2_pastis/basecfg.yaml --config_paths+=data/helios/v2_pastis/basecfg_helios_ts.yaml --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml --config_paths+=data/helios/v2_shared/helios_ts_simple_maxpool_sentinel2.yaml --cluster+=ai2/ceres-cirrascale --cluster+=ai2/saturn-cirrascale --rslp_project 2025_06_06_helios_finetuning --experiment_id v2_pastis_ts_maxpool_helios_latent_mim_space_time
+
+# sp
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_pastis/basecfg.yaml --config_paths+=data/helios/v2_pastis/basecfg_sp.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_pastis_uni_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_pastis/basecfg.yaml --config_paths+=data/helios/v2_pastis/basecfg_sp_ts.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_pastis_ts_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
 ```
diff --git a/data/helios/v2_pastis/basecfg_sp.yaml b/data/helios/v2_pastis/basecfg_sp.yaml
new file mode 100644
index 00000000..0fde7d38
--- /dev/null
+++ b/data/helios/v2_pastis/basecfg_sp.yaml
@@ -0,0 +1,62 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          segment:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 20
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+data:
+  init_args:
+    inputs:
+      image:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      targets:
+        data_type: "raster"
+        layers: ["label"]
+        bands: ["class"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              # PASTIS is missing B01 and B09.
+              # We use B02 to fill in B01 and B8A to fill in B09.
+              image: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              # PASTIS is missing B01 and B09.
+              # We use B02 to fill in B01 and B8A to fill in B09.
+              image: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
diff --git a/data/helios/v2_pastis/basecfg_sp_ts.yaml b/data/helios/v2_pastis/basecfg_sp_ts.yaml
new file mode 100644
index 00000000..323dec15
--- /dev/null
+++ b/data/helios/v2_pastis/basecfg_sp_ts.yaml
@@ -0,0 +1,139 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          segment:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 20
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+data:
+  init_args:
+    inputs:
+      sentinel2_0:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_1:
+        data_type: "raster"
+        layers: ["sentinel2.1"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_2:
+        data_type: "raster"
+        layers: ["sentinel2.2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_3:
+        data_type: "raster"
+        layers: ["sentinel2.3"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_4:
+        data_type: "raster"
+        layers: ["sentinel2.4"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_5:
+        data_type: "raster"
+        layers: ["sentinel2.5"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_6:
+        data_type: "raster"
+        layers: ["sentinel2.6"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_7:
+        data_type: "raster"
+        layers: ["sentinel2.7"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_8:
+        data_type: "raster"
+        layers: ["sentinel2.8"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_9:
+        data_type: "raster"
+        layers: ["sentinel2.9"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_10:
+        data_type: "raster"
+        layers: ["sentinel2.10"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      sentinel2_11:
+        data_type: "raster"
+        layers: ["sentinel2.11"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
+        passthrough: true
+      targets:
+        data_type: "raster"
+        layers: ["label"]
+        bands: ["class"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              # PASTIS is missing B01 and B09.
+              # We use B02 to fill in B01 and B8A to fill in B09.
+              sentinel2_0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              # PASTIS is missing B01 and B09.
+              # We use B02 to fill in B01 and B8A to fill in B09.
+              sentinel2_0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+              sentinel2_11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
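The `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]` selection above reorders the 10 PASTIS bands into the 12-band layout the pretrained encoder expects, reusing B02 for the missing B01 slot and B8A for the missing B09 slot. A minimal sketch of the effect (illustration only, not code from this diff; shapes are hypothetical):

```python
import torch

pastis_bands = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"]
selection = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7]

image = torch.rand(len(pastis_bands), 64, 64)  # CxHxW tensor from the dataset
remapped = image[selection]  # 12xHxW, matching the 12-band configs

assert remapped.shape[0] == 12
assert torch.equal(remapped[10], image[0])  # B01 slot filled with B02
assert torch.equal(remapped[11], image[7])  # B09 slot filled with B8A
```

The duplicated bands land at the end, which matches the `[..., "B11", "B12", "B01", "B09"]` band order used by the other tasks in this diff.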
diff --git a/data/helios/v2_satlas_marine_infra_128/README.md b/data/helios/v2_satlas_marine_infra_128/README.md
index 63c1eeba..cf5388e9 100644
--- a/data/helios/v2_satlas_marine_infra_128/README.md
+++ b/data/helios/v2_satlas_marine_infra_128/README.md
@@ -35,4 +35,8 @@ python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-
 
 # Helios with max pooling.
 python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/joer/v0.1_base_latent_mim_space_time/step165000 --patch_size 8 --encoder_embedding_size 768 --image_name favyen/rslphelios2 --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg_helios_ts.yaml --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml --config_paths+=data/helios/v2_shared/helios_ts_simple_maxpool_sentinel2.yaml --cluster+=ai2/ceres-cirrascale --cluster+=ai2/saturn-cirrascale --rslp_project 2025_06_06_helios_finetuning --experiment_id v2_satlas_marine_infra_128_ts_maxpool_helios_latent_mim_space_time
+
+# sp
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg_sp.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_satlas_marine_infra_128_uni_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_marine_infra_128/basecfg_sp_ts.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_satlas_marine_infra_128_ts_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
 ```
diff --git a/data/helios/v2_satlas_marine_infra_128/basecfg_sp.yaml b/data/helios/v2_satlas_marine_infra_128/basecfg_sp.yaml
new file mode 100644
index 00000000..1dd37335
--- /dev/null
+++ b/data/helios/v2_satlas_marine_infra_128/basecfg_sp.yaml
@@ -0,0 +1,58 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+            init_args:
+              target_resolution_factor: null
+          - class_path: rslearn.models.fpn.Fpn
+            init_args:
+              in_channels: [128, 256, 512, 1024]
+              out_channels: 128
+        decoders:
+          detect:
+            - class_path: rslearn.models.faster_rcnn.FasterRCNN
+              init_args:
+                downsample_factors: [4, 8, 16, 32]
+                num_channels: 128
+                num_classes: 3
+                anchor_sizes: [[32], [64], [128], [256]]
+data:
+  init_args:
+    inputs:
+      image:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      mask:
+        data_type: "raster"
+        layers: ["mask"]
+        bands: ["mask"]
+        passthrough: true
+        dtype: FLOAT32
+        is_target: true
+      targets:
+        data_type: "vector"
+        layers: ["label"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image"]
+            box_selectors: ["target/detect"]
diff --git a/data/helios/v2_satlas_marine_infra_128/basecfg_sp_ts.yaml b/data/helios/v2_satlas_marine_infra_128/basecfg_sp_ts.yaml
new file mode 100644
index 00000000..f16a4454
--- /dev/null
+++ b/data/helios/v2_satlas_marine_infra_128/basecfg_sp_ts.yaml
@@ -0,0 +1,92 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+            init_args:
+              target_resolution_factor: null
+          - class_path: rslearn.models.fpn.Fpn
+            init_args:
+              in_channels: [128, 256, 512, 1024]
+              out_channels: 128
+        decoders:
+          detect:
+            - class_path: rslearn.models.faster_rcnn.FasterRCNN
+              init_args:
+                downsample_factors: [4, 8, 16, 32]
+                num_channels: 128
+                num_classes: 3
+                anchor_sizes: [[32], [64], [128], [256]]
+data:
+  init_args:
+    inputs:
+      image1:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image2:
+        data_type: "raster"
+        layers: ["sentinel2.1"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image3:
+        data_type: "raster"
+        layers: ["sentinel2.2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image4:
+        data_type: "raster"
+        layers: ["sentinel2.3"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      mask:
+        data_type: "raster"
+        layers: ["mask"]
+        bands: ["mask"]
+        passthrough: true
+        dtype: FLOAT32
+        is_target: true
+      targets:
+        data_type: "vector"
+        layers: ["label"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              image1: []
+              image2: []
+              image3: []
+              image4: []
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              image1: []
+              image2: []
+              image3: []
+              image4: []
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image"]
+            box_selectors: ["target/detect"]
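In the `_sp_ts` configs the `Concatenate` transform lists each timestep with an empty selection, which (assuming an empty list means "take all bands" in rslearn's `Concatenate`) stacks the four 12-band images into one 48-channel input for the Swin encoder. A rough sketch of the resulting tensor shape:

```python
import torch

# Hypothetical per-timestep inputs as produced by the image1..image4 entries.
timesteps = {f"image{i}": torch.rand(12, 128, 128) for i in range(1, 5)}

# Assumed behavior of Concatenate with empty selections: all bands of each
# input, concatenated on the channel axis into the "image" output.
image = torch.cat([timesteps[f"image{i}"] for i in range(1, 5)], dim=0)
assert image.shape == (48, 128, 128)
```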
diff --git a/data/helios/v2_satlas_solar_farm_128/README.md b/data/helios/v2_satlas_solar_farm_128/README.md
index e3ec426c..c1794b42 100644
--- a/data/helios/v2_satlas_solar_farm_128/README.md
+++ b/data/helios/v2_satlas_solar_farm_128/README.md
@@ -35,4 +35,8 @@ python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-
 
 # Helios with max pooling.
 python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/joer/v0.1_base_latent_mim_space_time/step165000 --patch_size 8 --encoder_embedding_size 768 --image_name favyen/rslphelios2 --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg_helios_ts.yaml --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml --config_paths+=data/helios/v2_shared/helios_ts_simple_maxpool_sentinel2.yaml --cluster+=ai2/ceres-cirrascale --cluster+=ai2/saturn-cirrascale --rslp_project 2025_06_06_helios_finetuning --experiment_id v2_satlas_solar_farm_128_ts_maxpool_helios_latent_mim_space_time
+
+# sp
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg_sp.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_satlas_solar_farm_128_uni_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg.yaml --config_paths+=data/helios/v2_satlas_solar_farm_128/basecfg_sp_ts.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_satlas_solar_farm_128_ts_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
 ```
diff --git a/data/helios/v2_satlas_solar_farm_128/basecfg_sp.yaml b/data/helios/v2_satlas_solar_farm_128/basecfg_sp.yaml
new file mode 100644
index 00000000..20f2ddb6
--- /dev/null
+++ b/data/helios/v2_satlas_solar_farm_128/basecfg_sp.yaml
@@ -0,0 +1,59 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          segment:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 2
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+data:
+  init_args:
+    inputs:
+      image:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      mask:
+        data_type: "raster"
+        layers: ["mask"]
+        bands: ["mask"]
+        passthrough: true
+        dtype: FLOAT32
+        is_target: true
+      targets:
+        data_type: "raster"
+        layers: ["label_raster"]
+        bands: ["label"]
+        dtype: INT32
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
diff --git a/data/helios/v2_satlas_solar_farm_128/basecfg_sp_ts.yaml b/data/helios/v2_satlas_solar_farm_128/basecfg_sp_ts.yaml
new file mode 100644
index 00000000..d303c749
--- /dev/null
+++ b/data/helios/v2_satlas_solar_farm_128/basecfg_sp_ts.yaml
@@ -0,0 +1,93 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          segment:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 2
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+data:
+  init_args:
+    inputs:
+      image1:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image2:
+        data_type: "raster"
+        layers: ["sentinel2.1"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image3:
+        data_type: "raster"
+        layers: ["sentinel2.2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      image4:
+        data_type: "raster"
+        layers: ["sentinel2.3"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      mask:
+        data_type: "raster"
+        layers: ["mask"]
+        bands: ["mask"]
+        passthrough: true
+        dtype: FLOAT32
+        is_target: true
+      targets:
+        data_type: "raster"
+        layers: ["label_raster"]
+        bands: ["label"]
+        dtype: INT32
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              image1: []
+              image2: []
+              image3: []
+              image4: []
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.concatenate.Concatenate
+          init_args:
+            selections:
+              image1: []
+              image2: []
+              image3: []
+              image4: []
+            output_selector: image
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
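These configs apply `rslp.transforms.mask.Mask` after normalization. Judging from the input definitions, it plausibly zeroes out image pixels wherever the mask band is zero so the model cannot learn from imagery outside the labeled region; the sketch below is an assumption about that behavior, not the actual class:

```python
import torch

image = torch.rand(12, 128, 128)  # normalized Sentinel-2 input
mask = (torch.rand(1, 128, 128) > 0.2).float()  # 1 = labeled, 0 = ignore

# Assumed masking behavior: broadcast the single mask band over all channels.
masked_image = image * mask
```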
diff --git a/data/helios/v2_sentinel2_vessel_attribute/README.md b/data/helios/v2_sentinel2_vessel_attribute/README.md
index 9745a1b2..3b852b64 100644
--- a/data/helios/v2_sentinel2_vessel_attribute/README.md
+++ b/data/helios/v2_sentinel2_vessel_attribute/README.md
@@ -20,4 +20,7 @@ python -m rslp.main common beaker_train --config_paths+=data/helios/v2_sentinel2
 
 # Helios
 python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/joer/v0.1_base_latent_mim_space_time/step165000 --patch_size 8 --encoder_embedding_size 768 --image_name favyen/rslphelios2 --config_paths+=data/helios/v2_sentinel2_vessel_attribute/basecfg.yaml --config_paths+=data/helios/v2_sentinel2_vessel_attribute/basecfg_helios.yaml --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml --cluster+=ai2/ceres-cirrascale --cluster+=ai2/saturn-cirrascale --rslp_project 2025_06_06_helios_finetuning --experiment_id v2_sentinel2_vessel_attribute_helios_latent_mim_space_time
+
+# sp
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_sentinel2_vessel_attribute/basecfg.yaml --config_paths+=data/helios/v2_sentinel2_vessel_attribute/basecfg_sp.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_sentinel2_vessel_attribute_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
 ```
diff --git a/data/helios/v2_sentinel2_vessel_attribute/basecfg_sp.yaml b/data/helios/v2_sentinel2_vessel_attribute/basecfg_sp.yaml
new file mode 100644
index 00000000..8376a8e1
--- /dev/null
+++ b/data/helios/v2_sentinel2_vessel_attribute/basecfg_sp.yaml
@@ -0,0 +1,83 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+            init_args:
+              target_resolution_factor: null
+        decoders:
+          length:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 1
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.regression.RegressionHead
+          width:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 1
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.regression.RegressionHead
+          speed:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 1
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.regression.RegressionHead
+          heading_x:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 1
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.regression.RegressionHead
+          heading_y:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 1
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.regression.RegressionHead
+          ship_type:
+            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
+              init_args:
+                in_channels: 1024
+                out_channels: 9
+                num_conv_layers: 2
+                num_fc_layers: 2
+            - class_path: rslearn.train.tasks.classification.ClassificationHead
+data:
+  init_args:
+    inputs:
+      image:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      info:
+        data_type: "vector"
+        layers: ["info"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeFlip
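The vessel-attribute config regresses `heading_x` and `heading_y` as two separate heads rather than the heading angle directly. Representing heading as a 2D unit vector avoids the discontinuity at 0/360 degrees; the sin/cos convention below is an assumption for illustration, not taken from this diff:

```python
import math

def heading_to_xy(heading_degrees: float) -> tuple[float, float]:
    # Encode heading as a point on the unit circle.
    radians = math.radians(heading_degrees)
    return math.sin(radians), math.cos(radians)

def xy_to_heading(x: float, y: float) -> float:
    # Recover the angle; regression errors near 0/360 stay small because
    # 359 degrees and 1 degree map to nearby (x, y) points.
    return math.degrees(math.atan2(x, y)) % 360

assert abs(xy_to_heading(*heading_to_xy(359.0)) - 359.0) < 1e-6
```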
diff --git a/data/helios/v2_sentinel2_vessels_128/README.md b/data/helios/v2_sentinel2_vessels_128/README.md
index 0196db98..5a233671 100644
--- a/data/helios/v2_sentinel2_vessels_128/README.md
+++ b/data/helios/v2_sentinel2_vessels_128/README.md
@@ -20,4 +20,7 @@ python -m rslp.main common beaker_train --config_paths+=data/helios/v2_sentinel2
 
 # Helios
 python -m rslp.main helios launch_finetune --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/joer/v0.1_base_latent_mim_space_time/step165000 --patch_size 8 --encoder_embedding_size 768 --image_name favyen/rslphelios2 --config_paths+=data/helios/v2_sentinel2_vessels_128/basecfg.yaml --config_paths+=data/helios/v2_sentinel2_vessels_128/basecfg_helios.yaml --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml --cluster+=ai2/ceres-cirrascale --cluster+=ai2/saturn-cirrascale --rslp_project 2025_06_06_helios_finetuning --experiment_id v2_sentinel2_vessels_128_helios_latent_mim_space_time
+
+# sp
+python -m rslp.main common beaker_train --config_paths+=data/helios/v2_sentinel2_vessels_128/basecfg.yaml --config_paths+=data/helios/v2_sentinel2_vessels_128/basecfg_sp.yaml --config_paths+=data/helios/v2_shared/sp.yaml --project_id 2025_06_06_helios_finetuning --experiment_id v2_sentinel2_vessels_128_sp --cluster+=ai2/ceres --cluster+=ai2/jupiter '--weka_mounts+={"bucket_name":"dfive-default","mount_path":"/weka/dfive-default"}' --image_name favyen/rslphelios13
 ```
diff --git a/data/helios/v2_sentinel2_vessels_128/basecfg_sp.yaml b/data/helios/v2_sentinel2_vessels_128/basecfg_sp.yaml
new file mode 100644
index 00000000..cf6af70f
--- /dev/null
+++ b/data/helios/v2_sentinel2_vessels_128/basecfg_sp.yaml
@@ -0,0 +1,58 @@
+model:
+  init_args:
+    model:
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+            init_args:
+              target_resolution_factor: null
+          - class_path: rslearn.models.fpn.Fpn
+            init_args:
+              in_channels: [128, 256, 512, 1024]
+              out_channels: 128
+        decoders:
+          detect:
+            - class_path: rslearn.models.faster_rcnn.FasterRCNN
+              init_args:
+                downsample_factors: [4, 8, 16, 32]
+                num_channels: 128
+                num_classes: 2
+                anchor_sizes: [[32], [64], [128], [256]]
+data:
+  init_args:
+    inputs:
+      image:
+        data_type: "raster"
+        layers: ["sentinel2"]
+        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
+        passthrough: true
+        dtype: FLOAT32
+      mask:
+        data_type: "raster"
+        layers: ["mask"]
+        bands: ["mask"]
+        passthrough: true
+        dtype: INT32
+        is_target: true
+      targets:
+        data_type: "vector"
+        layers: ["label"]
+        is_target: true
+    default_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+    train_config:
+      transforms:
+        - class_path: rslearn.train.transforms.normalize.Normalize
+          init_args:
+            mean: 0
+            std: 10000
+        - class_path: rslp.transforms.mask.Mask
+        - class_path: rslearn.train.transforms.flip.Flip
+          init_args:
+            image_selectors: ["image"]
+            box_selectors: ["target/detect"]
diff --git a/data/helios/v2_shared/sp.yaml b/data/helios/v2_shared/sp.yaml
new file mode 100644
index 00000000..8b6f5840
--- /dev/null
+++ b/data/helios/v2_shared/sp.yaml
@@ -0,0 +1,13 @@
+model:
+  init_args:
+    restore_config:
+      restore_path: /weka/dfive-default/rslearn-eai/projects/swin_pretrain/pretrain_4mod_1conv/checkpoints/epoch130.ckpt.bak
+      selector: ["state_dict"]
+      remap_prefixes:
+        - ["model.encoder.0.", "encoder.0."]
+trainer:
+  callbacks+:
+    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
+      init_args:
+        module_selector: ["model", "encoder", 0]
+        unfreeze_at_epoch: 10
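The shared `sp.yaml` restores the pretrained Swin weights and remaps the `model.encoder.0.` prefix from the pretraining checkpoint to the `encoder.0.` prefix the fine-tuning model expects. A sketch of what that remap amounts to with a plain state dict (illustrative only; rslearn's `restore_config` performs the actual loading):

```python
import torch

# Path from sp.yaml; "state_dict" is the selector configured there.
ckpt = torch.load("epoch130.ckpt.bak", map_location="cpu")
state_dict = ckpt["state_dict"]

old_prefix, new_prefix = "model.encoder.0.", "encoder.0."
remapped = {}
for key, value in state_dict.items():
    if key.startswith(old_prefix):
        # Pretraining wrapped the encoder in an extra "model." module.
        remapped[new_prefix + key[len(old_prefix):]] = value
```

The `FreezeUnfreeze` callback then keeps this restored encoder frozen for the first 10 epochs, so only the randomly initialized decoders train at first.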
diff --git a/data/swin_pretrain/config_1conv.yaml b/data/swin_pretrain/config_1conv.yaml
new file mode 100644
index 00000000..18940123
--- /dev/null
+++ b/data/swin_pretrain/config_1conv.yaml
@@ -0,0 +1,140 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_1conv
diff --git a/data/swin_pretrain/config_1conv_crossattn.yaml b/data/swin_pretrain/config_1conv_crossattn.yaml
new file mode 100644
index 00000000..d413bee8
--- /dev/null
+++ b/data/swin_pretrain/config_1conv_crossattn.yaml
@@ -0,0 +1,140 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model_crossattn.Model
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_1conv_crossattn
diff --git a/data/swin_pretrain/config_1conv_vit.yaml b/data/swin_pretrain/config_1conv_vit.yaml
new file mode 100644
index 00000000..0a3b864a
--- /dev/null
+++ b/data/swin_pretrain/config_1conv_vit.yaml
@@ -0,0 +1,158 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.helios.model.Helios
+            init_args:
+              checkpoint_path: "/weka/dfive-default/helios/checkpoints/yawenzzzz/lmim_cross_random_contrastive_decode_wc_osm_srtm/step220000"
+              selector: ["encoder"]
+              forward_kwargs:
+                patch_size: 8
+          - class_path: rslearn.models.unet.UNetDecoder
+            init_args:
+              in_channels: [[8, 768]]
+              out_channels: 128
+              conv_layers_per_resolution: 2
+              target_resolution_factor: 1
+              num_channels:
+                4: 256
+                2: 128
+                1: 128
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+    patch_size: 8
+    min_size: 8
+    max_size: 96
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_1conv_vit
diff --git a/data/swin_pretrain/config_1conv_vit_random.yaml b/data/swin_pretrain/config_1conv_vit_random.yaml
new file mode 100644
index 00000000..1b4b49d5
--- /dev/null
+++ b/data/swin_pretrain/config_1conv_vit_random.yaml
@@ -0,0 +1,159 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.helios.model.Helios
+            init_args:
+              checkpoint_path: "/weka/dfive-default/helios/checkpoints/yawenzzzz/lmim_cross_random_contrastive_decode_wc_osm_srtm/step220000"
+              selector: ["encoder"]
+              forward_kwargs:
+                patch_size: 8
+              random_initialization: true
+          - class_path: rslearn.models.unet.UNetDecoder
+            init_args:
+              in_channels: [[8, 768]]
+              out_channels: 128
+              conv_layers_per_resolution: 2
+              target_resolution_factor: 1
+              num_channels:
+                4: 256
+                2: 128
+                1: 128
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+    patch_size: 8
+    min_size: 8
+    max_size: 96
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_1conv_vit_random
diff --git a/data/swin_pretrain/config_2conv.yaml b/data/swin_pretrain/config_2conv.yaml
new file mode 100644
index 00000000..02ce54bc
--- /dev/null
+++ b/data/swin_pretrain/config_2conv.yaml
@@ -0,0 +1,160 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_2conv
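The `config_1conv`, `config_2conv`, and `config_3conv` (below) variants differ only in how many 3x3 convolutions precede the final 1x1 per-task head. A sketch of the "2conv" head in plain PyTorch, assuming `rslearn.models.conv.Conv` is roughly a `Conv2d` followed by an activation (an assumption; see that class for the exact behavior):

```python
import torch
from torch import nn

# Hypothetical stand-in for the openstreetmap "2conv" decoder head.
two_conv_head = nn.Sequential(
    nn.Conv2d(128, 128, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(128, 31, kernel_size=1),  # per-pixel logits for 31 OSM classes
)

logits = two_conv_head(torch.rand(1, 128, 64, 64))
assert logits.shape == (1, 31, 64, 64)
```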
diff --git a/data/swin_pretrain/config_3conv.yaml b/data/swin_pretrain/config_3conv.yaml
new file mode 100644
index 00000000..8dddc6ae
--- /dev/null
+++ b/data/swin_pretrain/config_3conv.yaml
@@ -0,0 +1,180 @@
+model:
+  class_path: rslearn.train.lightning_module.RslearnLightningModule
+  init_args:
+    model:
+      class_path: rslearn.models.multitask.MultiTaskModel
+      init_args:
+        encoder:
+          - class_path: rslp.swin_pretrain.model.Model
+        decoders:
+          openstreetmap:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 31
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          cdl:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 257
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          worldcover:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 102
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+          chm:
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 128
+                kernel_size: 3
+            - class_path: rslearn.models.conv.Conv
+              init_args:
+                in_channels: 128
+                out_channels: 1
+                kernel_size: 1
+                activation:
+                  class_path: torch.nn.Identity
+            - class_path: rslearn.models.pick_features.PickFeatures
+              init_args:
+                indexes: [0]
+                collapse: true
+            - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead
+    scheduler:
+      class_path: rslearn.train.scheduler.PlateauScheduler
+      init_args:
+        factor: 0.2
+        patience: 2
+        min_lr: 0
+        cooldown: 10
+    optimizer:
+      class_path: rslearn.train.optimizer.AdamW
+      init_args:
+        lr: 0.00002
+data:
+  class_path: rslp.swin_pretrain.dataset.HeliosDataModule
+  init_args:
+    ds_path: /weka/dfive-default/helios/dataset/osm_sampling
+    task:
+      class_path: rslearn.train.tasks.multi_task.MultiTask
+      init_args:
+        tasks:
+          openstreetmap:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 31
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          cdl:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 257
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          worldcover:
+            class_path: rslearn.train.tasks.segmentation.SegmentationTask
+            init_args:
+              num_classes: 102
+              metric_kwargs:
+                average: "micro"
+              zero_is_invalid: true
+          chm:
+            class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask
+            init_args:
+              scale_factor: 0.1
+              metric_mode: "l1"
+              nodata_value: 0
+        input_mapping:
+          openstreetmap:
+            10_openstreetmap_raster: "targets"
+          cdl:
+            10_cdl: "targets"
+          worldcover:
+            10_worldcover: "targets"
+          chm:
+            10_wri_canopy_height_map: "targets"
+    input_modalities:
+      - "10_sentinel2_l2a_monthly"
+    target_modalities:
+      - "10_openstreetmap_raster"
+      - "10_cdl"
+      - "10_worldcover"
+      - "10_wri_canopy_height_map"
+    num_val_examples: 1024
+    batch_size: 8
+    num_workers: 32
+trainer:
+  max_epochs: 500
+  callbacks:
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: "epoch"
+    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      init_args:
+        save_top_k: 1
+        save_last: true
+        monitor: val_loss
+        mode: min
+rslp_project: swin_pretrain
+rslp_experiment: pretrain_4mod_3conv
diff --git a/rslp/swin_pretrain/__init__.py b/rslp/swin_pretrain/__init__.py
new file mode 100644
index 00000000..ff80b437
--- /dev/null
+++ b/rslp/swin_pretrain/__init__.py
@@ -0,0 +1 @@
+"""Components for pre-training Swin models on the Helios dataset."""
diff --git a/rslp/swin_pretrain/dataset.py b/rslp/swin_pretrain/dataset.py
new file mode 100644
index 00000000..a567a1e5
--- /dev/null
+++ b/rslp/swin_pretrain/dataset.py
@@ -0,0 +1,518 @@
+"""Load rslearn-compatible data from a Helios dataset folder."""
+
+import hashlib
+import random
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import Any
+
+import lightning as L
+import numpy as np
+import rasterio
+import torch
+from einops import rearrange
+from rasterio.crs import CRS
+from rslearn.train.data_module import collate_fn
+from rslearn.train.tasks import Task
+from rslearn.utils.geometry import Projection
+from torch.utils.data import DataLoader, DistributedSampler
+
+from rslp.log_utils import get_logger
+
+logger = get_logger(__file__)
+
+
+ # If there are multiple suffixes, the bands should be stacked. + suffixes: list[tuple[str, int]] + + # Number of overall bands. + # If num_bands = 1 but there are multiple bands across the suffixes, then it will + # be converted from one-hot encoding to integer class label. + # This corresponds to the number of classes for categorical modalities. + num_bands: int + + # How much to divide by for normalization. + # Only applies when using this modality as an input. + norm_factor: float | None = None + + is_multitemporal: bool = False + + +MODALITIES = { + "10_sentinel2_l2a_monthly": ModalityInfo( + suffixes=[ + ("_10.tif", 4), + ("_20.tif", 6), + ("_40.tif", 2), + ], + num_bands=12, + norm_factor=10000, + is_multitemporal=True, + ), + "10_openstreetmap_raster": ModalityInfo( + suffixes=[("_2.5.tif", 30)], + num_bands=1, + ), + "10_cdl": ModalityInfo( + suffixes=[("_10.tif", 1)], + num_bands=1, + ), + "10_worldcover": ModalityInfo( + suffixes=[("_10.tif", 1)], + num_bands=1, + ), + "10_wri_canopy_height_map": ModalityInfo( + suffixes=[("_10.tif", 1)], + num_bands=1, + ), +} +TILE_SIZE = 256 + + +class CollateFunction: + """Collate function for Helios dataset.""" + + def __init__( + self, + randomize: bool = True, + min_size: int = 256, + max_size: int = 256, + patch_size: int = 32, + ): + """Create a new CollateFunction. + + Args: + randomize: whether to randomize the selection of options like number of + timesteps or height/width for cropping. Should be true for training and + false for validation. + min_size: minimum size to crop the input. + max_size: maximum size to crop the input. + patch_size: ensure the cropped input is a multiple of this amount. + """ + self.randomize = randomize + self.min_size = min_size + self.max_size = max_size + self.patch_size = patch_size + + def __call__( + self, batch: list[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]] + ) -> tuple: + """Collate batch of training examples. + + We just make list of the inputs and another of the targets. + + Args: + batch: list of input/target/metadata for each example + + Returns: + a tuple (inputs, targets, metadatas) + """ + inputs, targets, metadatas = collate_fn(batch) + + # Find minimum number of available timesteps. + multitemporal_modalities = [ + (modality, info) + for modality, info in MODALITIES.items() + if info.is_multitemporal + ] + minimum_available_timesteps: int | None = None + for input_dict in inputs: + for modality, info in multitemporal_modalities: + if modality not in input_dict: + continue + cur_timesteps = input_dict[modality].shape[0] // info.num_bands + if minimum_available_timesteps is None: + minimum_available_timesteps = cur_timesteps + else: + minimum_available_timesteps = min( + minimum_available_timesteps, cur_timesteps + ) + + # Randomly pick a subset of timesteps, along with spatial crop size. 
+ assert minimum_available_timesteps is not None + if self.randomize: + num_timesteps = random.randint(1, minimum_available_timesteps) + crop_size = random.randint(self.min_size, self.max_size) + crop_size = (crop_size // self.patch_size) * self.patch_size + crop_h_start = random.randint(0, TILE_SIZE - crop_size) + crop_w_start = random.randint(0, TILE_SIZE - crop_size) + else: + # Seed deterministically from the window name with a stable hash (the + # built-in hash() is salted per process, so it would differ across + # dataloader workers). + rng = np.random.default_rng( + int(hashlib.sha256(metadatas[0]["window_name"].encode()).hexdigest(), 16) + % 65536 + ) + num_timesteps = rng.integers(minimum_available_timesteps) + 1 + crop_size = self.min_size + rng.integers(self.max_size - self.min_size + 1) + crop_size = (crop_size // self.patch_size) * self.patch_size + crop_h_start = 0 + crop_w_start = 0 + + # Temporal subset. + for input_dict in inputs: + for modality, info in multitemporal_modalities: + if modality not in input_dict: + continue + image = input_dict[modality] + # Reshape so the timesteps and bands are separate. + cur_timesteps = image.shape[0] // MODALITIES[modality].num_bands + image = rearrange( + image, + "(t c) h w -> t c h w", + t=cur_timesteps, + c=MODALITIES[modality].num_bands, + ) + # Subset the timesteps. + available_timesteps = list(range(image.shape[0])) + if self.randomize: + selected_timesteps = random.sample( + available_timesteps, num_timesteps + ) + else: + selected_timesteps = available_timesteps[0:num_timesteps] + image = image[sorted(selected_timesteps)] + # Reshape back so the timesteps and bands are stacked. + image = rearrange(image, "t c h w -> (t c) h w") + input_dict[modality] = image + + # TODO: avoid hardcoding Sentinel-2 as the model's "image" input. + input_dict["image"] = input_dict["10_sentinel2_l2a_monthly"] + + # Spatial crop. + for input_dict in inputs: + for modality in list(input_dict.keys()): + input_dict[modality] = input_dict[modality][ + :, + crop_h_start : crop_h_start + crop_size, + crop_w_start : crop_w_start + crop_size, + ] + for target_dict in targets: + for task_name in list(target_dict.keys()): + for sub_name in list(target_dict[task_name].keys()): + image = target_dict[task_name][sub_name] + if len(image.shape) == 2: + image = image[ + crop_h_start : crop_h_start + crop_size, + crop_w_start : crop_w_start + crop_size, + ] + else: + image = image[ + :, + crop_h_start : crop_h_start + crop_size, + crop_w_start : crop_w_start + crop_size, + ] + target_dict[task_name][sub_name] = image + + return (inputs, targets, metadatas) + + +class HeliosDataset(torch.utils.data.Dataset): + """A dataset for Helios data.""" + + def __init__( + self, + ds_path: Path, + task: Task, + input_modalities: list[str], + target_modalities: list[str], + limit: int | None = None, + skip: int = 0, + ): + """Create a new HeliosDataset. + + Args: + ds_path: the path to the Helios dataset folder. + task: the task to train on. + input_modalities: list of modalities to input. + target_modalities: list of modalities to use as targets. + limit: limit to this many samples. + skip: skip this many initial samples. + """ + self.ds_path = ds_path + self.task = task + self.input_modalities = input_modalities + self.target_modalities = target_modalities + + # Get the unique tiles that have at least one of the input modalities.
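The temporal subsetting above relies on an einops round trip between the channel-stacked and separated layouts. A self-contained sketch of that round trip (shapes chosen arbitrarily):

```python
import torch
from einops import rearrange

# Separate timesteps from channel-stacked bands, keep a sorted subset, restack.
t, c, h, w = 4, 12, 2, 2
image = torch.arange(t * c * h * w, dtype=torch.float32).reshape(t * c, h, w)
image = rearrange(image, "(t c) h w -> t c h w", t=t, c=c)
image = image[sorted([3, 1])]  # kept timesteps stay in chronological order
image = rearrange(image, "t c h w -> (t c) h w")
assert image.shape == (2 * c, h, w)
```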
+ tile_set = set() + for modality in self.input_modalities: + modality_info = MODALITIES[modality] + suffix, _ = modality_info.suffixes[0] + modality_dir = ds_path / modality + logger.info(f"Getting tiles in {modality_dir}") + for fname in modality_dir.iterdir(): + if not fname.name.endswith(suffix): + continue + tile_name = fname.name[: -len(suffix)] + tile_set.add(tile_name) + + logger.info(f"Discovered {len(tile_set)} tiles total") + self.tile_list = list(tile_set) + self.tile_list.sort( + key=lambda tile_name: hashlib.sha256(tile_name.encode()).hexdigest() + ) + + self.tile_list = self.tile_list[skip:] + if limit is not None: + self.tile_list = self.tile_list[0:limit] + + logger.info(f"Finished setup with {len(self.tile_list)} tiles") + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.tile_list) + + def _load_modality(self, tile_name: str, modality: str) -> torch.Tensor | None: + """Load the given modality for a tile, or return None if a file is missing.""" + info = MODALITIES[modality] + + # Get list of images across suffixes. + # Each image is TxCxHxW for multitemporal, otherwise CxHxW. + image_list = [] + + for suffix, suffix_bands in info.suffixes: + fname = self.ds_path / modality / (tile_name + suffix) + if not fname.exists(): + return None + + with rasterio.open(fname) as src: + array = torch.from_numpy(src.read()) + + # Resize to TILE_SIZE. + if array.shape[1] < TILE_SIZE: + factor = TILE_SIZE // array.shape[1] + if ( + array.shape[1] * factor != TILE_SIZE + or array.shape[2] * factor != TILE_SIZE + ): + raise ValueError(f"bad array shape {array.shape}") + array = torch.repeat_interleave(array, repeats=factor, dim=1) + array = torch.repeat_interleave(array, repeats=factor, dim=2) + elif array.shape[1] > TILE_SIZE: + factor = array.shape[1] // TILE_SIZE + if ( + array.shape[1] != TILE_SIZE * factor + or array.shape[2] != TILE_SIZE * factor + ): + raise ValueError(f"bad array shape {array.shape}") + # Use max pooling since it works better for categorical modalities (so + # we do not lose the finer-grained detail). + array = torch.nn.functional.max_pool2d( + array, kernel_size=factor, stride=factor + ) + + # Separate the stacked timesteps into their own axis. + if info.is_multitemporal: + if array.shape[0] % suffix_bands != 0: + raise ValueError( + f"array has {array.shape[0]} bands but that is not a multiple of {suffix_bands}" + ) + num_timesteps = array.shape[0] // suffix_bands + array = array.reshape( + (num_timesteps, suffix_bands, array.shape[1], array.shape[2]) + ) + + elif array.shape[0] != suffix_bands: + raise ValueError( + f"non-multi-temporal array {array.shape[0]} != {suffix_bands}" + ) + + image_list.append(array) + + # Stack the bands on the channel axis. + if info.is_multitemporal: + image = torch.cat(image_list, dim=1) + # Change timesteps and bands back to being combined. + image = image.reshape((-1, image.shape[2], image.shape[3])) + else: + image = torch.cat(image_list, dim=0) + + # Convert a one-hot encoding to an integer class label if needed. + # This only works for non-multi-temporal data.
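The two resize paths in `_load_modality` are worth seeing side by side. A minimal sketch on a synthetic 2x2 raster, showing upsampling via `repeat_interleave` and downsampling via max pooling:

```python
import torch

# Upsample by repeating pixels (keeps categorical labels exact), then max-pool
# back down; unlike average pooling, max pooling never invents new class ids.
array = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # C=1, 2x2
up = torch.repeat_interleave(array, repeats=2, dim=1)
up = torch.repeat_interleave(up, repeats=2, dim=2)  # 1x4x4
down = torch.nn.functional.max_pool2d(up, kernel_size=2, stride=2)
assert torch.equal(down, array)
```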
+ if image.shape[0] > 1 and info.num_bands == 1: + image = image.argmax(dim=0, keepdim=True) + + if info.norm_factor is not None: + image = (image / info.norm_factor).to(dtype=torch.float32) + else: + image = image.to(dtype=torch.long) + + return image + + def __getitem__( + self, idx: int + ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Get the dataset item at the specified index.""" + tile_name = self.tile_list[idx] + raw_inputs = {} + passthrough_inputs = {} + + # Currently all of the input modalities need to be multi-temporal while the + # target modalities must not be. This may not be ideal. + # TODO: sub-sample from available input modalities. Need to adjust model to + # accept different subsets of inputs. + for modality in self.input_modalities: + passthrough_inputs[modality] = self._load_modality(tile_name, modality) + for modality in self.target_modalities: + # For the targets, we add one to the labels when the modality is present; + # otherwise we set the target to all zeros (zero is treated as invalid). + image = self._load_modality(tile_name, modality) + if image is None: + image = torch.zeros((1, TILE_SIZE, TILE_SIZE), dtype=torch.long) + else: + image = image + 1 + raw_inputs[modality] = image + + metadata = { + "group": "fake", + "window_name": tile_name, + "window_bounds": (0, 0, TILE_SIZE, TILE_SIZE), + "bounds": (0, 0, TILE_SIZE, TILE_SIZE), + "time_range": ( + datetime(2024, 1, 1, tzinfo=UTC), + datetime(2024, 2, 1, tzinfo=UTC), + ), + "projection": Projection(CRS.from_epsg(32610), 10, -10), + "dataset_source": None, + "patch_idx": 0, + "num_patches": 1, + } + + input_dict, target_dict = self.task.process_inputs( + raw_inputs, + metadata=metadata, + load_targets=True, + ) + input_dict.update(passthrough_inputs) + # input_dict, target_dict = self.transforms(input_dict, target_dict) + + return input_dict, target_dict, metadata + + +class HeliosDataModule(L.LightningDataModule): + """Data module for Helios data.""" + + def __init__( + self, + ds_path: Path, + task: Task, + input_modalities: list[str], + target_modalities: list[str], + num_val_examples: int, + batch_size: int, + num_workers: int, + min_size: int = 256, + max_size: int = 256, + patch_size: int = 32, + ) -> None: + """Initialize a new DataModule.""" + super().__init__() + self.ds_path = ds_path + self.task = task + self.input_modalities = input_modalities + self.target_modalities = target_modalities + self.num_val_examples = num_val_examples + self.batch_size = batch_size + self.num_workers = num_workers + self.min_size = min_size + self.max_size = max_size + self.patch_size = patch_size + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Args: + stage: Either 'fit' or 'validate' + """ + # Set up the training dataset only for the fit command. + if stage == "fit": + self.train_dataset = HeliosDataset( + ds_path=self.ds_path, + task=self.task, + input_modalities=self.input_modalities, + target_modalities=self.target_modalities, + skip=self.num_val_examples, + ) + + # Set up the validation dataset. + self.val_dataset = HeliosDataset( + ds_path=self.ds_path, + task=self.task, + input_modalities=self.input_modalities, + target_modalities=self.target_modalities, + limit=self.num_val_examples, + ) + + def train_dataloader(self) -> DataLoader[dict[str, torch.Tensor]]: + """Implement one or more PyTorch DataLoaders for training. + + Returns: + A collection of data loaders specifying training samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + dataset or sampler, or if the dataset or sampler has length 0.
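The +1 label shift in `__getitem__` pairs with `zero_is_invalid: true` in the task configs above. A small sketch of the convention (WorldCover-like class ids used as an assumed example):

```python
import torch

TILE_SIZE = 256
raw = torch.randint(0, 101, (1, TILE_SIZE, TILE_SIZE))  # raw class ids 0..100
present = raw + 1                                       # valid labels become 1..101
missing = torch.zeros((1, TILE_SIZE, TILE_SIZE), dtype=torch.long)  # all invalid
assert present.min() >= 1 and missing.max() == 0
```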
+ """ + kwargs = dict( + dataset=self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=CollateFunction( + randomize=True, + min_size=self.min_size, + max_size=self.max_size, + patch_size=self.patch_size, + ), + persistent_workers=True, + ) + if ( + self.trainer is not None + and self.trainer.world_size is not None + and self.trainer.world_size > 1 + ): + kwargs["sampler"] = DistributedSampler( + self.train_dataset, + num_replicas=self.trainer.world_size, + rank=self.trainer.global_rank, + shuffle=True, + ) + else: + kwargs["shuffle"] = True + + return DataLoader(**kwargs) + + def val_dataloader(self) -> DataLoader[dict[str, torch.Tensor]]: + """Implement one or more PyTorch DataLoaders for validation. + + Returns: + A collection of data loaders specifying validation samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + dataset or sampler, or if the dataset or sampler has length 0. + """ + kwargs = dict( + dataset=self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=CollateFunction( + randomize=False, + min_size=self.min_size, + max_size=self.max_size, + patch_size=self.patch_size, + ), + persistent_workers=True, + ) + if ( + self.trainer is not None + and self.trainer.world_size is not None + and self.trainer.world_size > 1 + ): + kwargs["sampler"] = DistributedSampler( + self.val_dataset, + num_replicas=self.trainer.world_size, + rank=self.trainer.global_rank, + shuffle=False, + ) + + return DataLoader(**kwargs) diff --git a/rslp/swin_pretrain/model.py b/rslp/swin_pretrain/model.py new file mode 100644 index 00000000..182f5a99 --- /dev/null +++ b/rslp/swin_pretrain/model.py @@ -0,0 +1,55 @@ +"""Model for pre-training Swin backbone on Helios dataset.""" + +from typing import Any + +import torch +from rslearn.models.simple_time_series import SimpleTimeSeries +from rslearn.models.swin import Swin +from rslearn.models.unet import UNetDecoder + + +class Model(torch.nn.Module): + """Model for pre-training.""" + + def __init__( + self, + target_resolution_factor: int | None = 1, + unet_out_channels: int | None = 128, + ) -> None: + """Initialize the model.""" + super().__init__() + self.target_resolution_factor = target_resolution_factor + # Currently this model can only handle one input image (Sentinel-2). + self.backbone = SimpleTimeSeries( + encoder=Swin( + arch="swin_v2_b", + pretrained=True, + input_channels=12, + output_layers=[1, 3, 5, 7], + ), + image_channels=12, + ) + if self.target_resolution_factor is not None: + self.unet = UNetDecoder( + in_channels=[[4, 128], [8, 256], [16, 512], [32, 1024]], + out_channels=unet_out_channels, + conv_layers_per_resolution=2, + target_resolution_factor=target_resolution_factor, + ) + + def forward( + self, + inputs: list[dict[str, Any]], + ) -> list[torch.Tensor]: + """Compute outputs from the wrapped module. + + Inputs: + inputs: input dicts that must include "image" key containing the image to + process. + """ + features = self.backbone(inputs) + if self.target_resolution_factor is None: + return features + + hr_features = self.unet(features, None) + return [hr_features] diff --git a/rslp/swin_pretrain/model_crossattn.py b/rslp/swin_pretrain/model_crossattn.py new file mode 100644 index 00000000..fe159355 --- /dev/null +++ b/rslp/swin_pretrain/model_crossattn.py @@ -0,0 +1,222 @@ +"""Model for pre-training Swin backbone on Helios dataset. + +This model uses a cross-attention mechanism for temporal pooling instead of max +pooling. 
Temporal pooling is needed since the Swin component is applied to each image +in the time series. +""" + +from typing import Any + +import torch +import torch.nn as nn +from einops import rearrange +from rslearn.models.swin import Swin +from rslearn.models.unet import UNetDecoder + + +def _get_1d_sincos_pos_embed( + embed_dim: int, length: int, device: torch.device +) -> torch.Tensor: + """Return [length, embed_dim] 1D sinusoidal PE.""" + assert embed_dim % 2 == 0, "embed_dim must be divisible by 2 for sin/cos." + pos = torch.arange(length, device=device, dtype=torch.float32) # [L] + dim = torch.arange(embed_dim // 2, device=device, dtype=torch.float32) # [D/2] + freqs = 1.0 / (10000 ** (dim / (embed_dim // 2))) + angles = pos[:, None] * freqs[None, :] # [L, D/2] + emb = torch.cat([angles.sin(), angles.cos()], dim=1) # [L, D] + return emb # [L, D] + + +def _get_2d_sincos_pos_embed( + embed_dim: int, h: int, w: int, device: torch.device +) -> torch.Tensor: + """Return [h*w, embed_dim] 2D sinusoidal PE (concatenated 1D encodings over y/x).""" + assert embed_dim % 2 == 0, "embed_dim must be divisible by 2 for 2D sin/cos." + half_dim = embed_dim // 2 + pe_h = _get_1d_sincos_pos_embed(half_dim, h, device) # [H, D/2] + pe_w = _get_1d_sincos_pos_embed(half_dim, w, device) # [W, D/2] + # meshgrid, then concatenate the per-axis encodings, as in ViT-style 2D PE. + y, x = torch.meshgrid( + torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij" + ) + pe = torch.cat([pe_h[y], pe_w[x]], dim=2) # [H, W, D] + pe = pe.view(h * w, embed_dim) # [HW, D] + return pe + + +class CrossAttentionTemporalPool(nn.Module): + """Temporal pooling via cross-attention with one dst token per spatial patch. + + Inputs: x of shape [B, T, C, H, W] + Outputs: pooled of shape [B, C, H, W] + """ + + def __init__( + self, + embed_dim: int, + num_heads: int = 8, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + ) -> None: + """Create a new CrossAttentionTemporalPool. + + Args: + embed_dim: the embedding dimension of the inputs. + num_heads: number of attention heads. + mlp_ratio: ratio of the MLP hidden dimension to embed_dim. + dropout: dropout probability; defaults to no dropout. + """ + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.norm_q = nn.LayerNorm(embed_dim) + self.norm_kv = nn.LayerNorm(embed_dim) + self.attn = nn.MultiheadAttention( + embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True + ) + self.drop = nn.Dropout(dropout) + + hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, embed_dim), + nn.Dropout(dropout), + ) + + @torch.no_grad() + def _make_pos( + self, T: int, H: int, W: int, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute spatial [HW, C] and temporal [T, C] sinusoidal embeddings.""" + pe_spatial = _get_2d_sincos_pos_embed(self.embed_dim, H, W, device) # [HW, C] + pe_time = _get_1d_sincos_pos_embed(self.embed_dim, T, device) # [T, C] + return pe_spatial, pe_time + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Model forward pass.
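A quick shape check of the positional-encoding helpers, runnable against this module as written (the helpers are private, so this is for inspection rather than public API use):

```python
import torch
from rslp.swin_pretrain.model_crossattn import (
    _get_1d_sincos_pos_embed,
    _get_2d_sincos_pos_embed,
)

device = torch.device("cpu")
pe_time = _get_1d_sincos_pos_embed(64, 6, device)        # one row per timestep
pe_spatial = _get_2d_sincos_pos_embed(64, 4, 5, device)  # one row per patch
assert pe_time.shape == (6, 64)
assert pe_spatial.shape == (20, 64)
```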
+ + Args: + x: the input feature tensor, [B, T, C, H, W] + + Returns: + an output feature tensor [B, C, H, W] + """ + B, T, C, H, W = x.shape + assert C == self.embed_dim, f"Channel dim {C} != embed_dim {self.embed_dim}" + + device = x.device + N = H * W + + # Positional encodings. + pe_spatial, pe_time = self._make_pos(T, H, W, device) # [N, C], [T, C] + + # Prepare src tokens: one per (timestep, patch). + # Add spatial+time PE to each src token. + x_tokens = rearrange(x, "b t c h w -> b t (h w) c") # [B, T, N, C] + x_tokens = x_tokens + pe_time.view(1, T, 1, C) + pe_spatial.view(1, 1, N, C) + x_src = rearrange(x_tokens, "b t n c -> (b n) t c") # [B*N, T, C] + x_src = self.norm_kv(x_src) # pre-norm for KV + + # Prepare dst queries: one per spatial patch; use only spatial PE (content-free query). + q = pe_spatial.view(1, N, C).expand(B, N, C) # [B, N, C] + q = rearrange(q, "b n c -> (b n) 1 c") # [B*N, 1, C] + q = self.norm_q(q) + + # Cross-attention: Q (1 token) attends over the src sequence of length T for each patch. + out, _ = self.attn(q, x_src, x_src, need_weights=False) # [B*N, 1, C] + out = q + self.drop(out) # residual + out = out + self.mlp(out) # MLP block + out = out.squeeze(1) # [B*N, C] + + # Back to [B, C, H, W]. + out = rearrange(out, "(b n) c -> b c n", b=B, n=N) + out = out.view(B, C, H, W) + return out + + +class Model(torch.nn.Module): + """Model for pre-training.""" + + def __init__( + self, + target_resolution_factor: int | None = 1, + unet_out_channels: int | None = 128, + cross_attn_heads: int = 8, + cross_attn_mlp_ratio: float = 4.0, + cross_attn_dropout: float = 0.0, + ) -> None: + """Initialize the model.""" + super().__init__() + self.target_resolution_factor = target_resolution_factor + # Currently this model can only handle one input modality (Sentinel-2). + self.encoder = Swin( + arch="swin_v2_b", + pretrained=True, + input_channels=12, + output_layers=[1, 3, 5, 7], + ) + encoder_channels = [128, 256, 512, 1024] + self.temporal_poolers = nn.ModuleList( + [ + CrossAttentionTemporalPool( + embed_dim=c, + num_heads=cross_attn_heads, + mlp_ratio=cross_attn_mlp_ratio, + dropout=cross_attn_dropout, + ) + for c in encoder_channels + ] + ) + + if self.target_resolution_factor is not None: + self.unet = UNetDecoder( + in_channels=[[4, 128], [8, 256], [16, 512], [32, 1024]], + out_channels=unet_out_channels, + conv_layers_per_resolution=2, + target_resolution_factor=target_resolution_factor, + ) + + def forward( + self, + inputs: list[dict[str, Any]], + ) -> list[torch.Tensor]: + """Compute outputs for the given inputs. + + Args: + inputs: input dicts that must include an "image" key containing the image + to process. + """ + # Apply the image encoder on each image in the time series. + images = torch.stack([inp["image"] for inp in inputs], dim=0) + image_channels = 12 + batch_size = len(inputs) + assert images.shape[1] % image_channels == 0 + n_images = images.shape[1] // image_channels + # Reshape images to B*T x C x H x W. + images = rearrange( + images, "b (t c) h w -> (b t) c h w", t=n_images, c=image_channels + ) + # Now add the "image" key expected by the encoder. + batched_inputs = [{"image": image} for image in images] + # The encoder provides one feature map per resolution. + encoder_feats = self.encoder(batched_inputs) + all_features = [ + rearrange(feat_map, "(b t) c h w -> b t c h w", b=batch_size, t=n_images) + for feat_map in encoder_feats + ] + + # Compute pooled features using cross attention.
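The fold/unfold trick in `Model.forward` (the time axis merged into the batch axis so Swin runs once per timestep, then separated again before temporal pooling) can be checked independently:

```python
import torch
from einops import rearrange

b, t, c, h, w = 2, 3, 12, 32, 32
images = torch.zeros((b, t * c, h, w))
folded = rearrange(images, "b (t c) h w -> (b t) c h w", t=t, c=c)
assert folded.shape == (b * t, c, h, w)
unfolded = rearrange(folded, "(b t) c h w -> b t c h w", b=b, t=t)
assert unfolded.shape == (b, t, c, h, w)
```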
+ pooled_features = [ + pool(feat_map) + for pool, feat_map in zip(self.temporal_poolers, all_features) + ] + + if self.target_resolution_factor is None: + return pooled_features + + hr_features = self.unet(pooled_features, None) + return [hr_features] diff --git a/tests/unit/swin_pretrain/__init__.py b/tests/unit/swin_pretrain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/swin_pretrain/test_dataset.py b/tests/unit/swin_pretrain/test_dataset.py new file mode 100644 index 00000000..1110a7b6 --- /dev/null +++ b/tests/unit/swin_pretrain/test_dataset.py @@ -0,0 +1,173 @@ +"""Unit tests for rslp.swin_pretrain.dataset.""" + +from typing import Any + +import torch + +from rslp.swin_pretrain.dataset import TILE_SIZE, CollateFunction + + +class TestCollateFunction: + """Tests for CollateFunction.""" + + def make_example( + self, + height: int, + width: int, + input_modalities: dict[str, int], + segment_targets: list[str], + example_idx: int, + timesteps: int = 1, + ) -> tuple[dict, dict, dict]: + """Make an example to pass to CollateFunction.""" + input_dict: dict[str, torch.Tensor] = {} + for modality, num_bands in input_modalities.items(): + image = torch.zeros( + (timesteps * num_bands, height, width), dtype=torch.float32 + ) + # Set each timestep to a different value so tests can distinguish them. + for timestep in range(timesteps): + image[timestep * num_bands : (timestep + 1) * num_bands] = timestep + input_dict[modality] = image + target_dict: dict[str, Any] = {} + for modality in segment_targets: + target_dict[modality] = { + "classes": torch.zeros((height, width), dtype=torch.int32), + "valid": torch.zeros((height, width), dtype=torch.int32), + } + metadata = { + "window_name": f"window{example_idx}", + } + return input_dict, target_dict, metadata + + def test_random_cropping(self) -> None: + """Test that we get random crops of different sizes.""" + min_size = 16 + max_size = 48 + patch_size = 8 + collate_fn = CollateFunction( + randomize=True, + min_size=min_size, + max_size=max_size, + patch_size=patch_size, + ) + widths = set() + for _ in range(10): + batch = [ + self.make_example( + height=TILE_SIZE, + width=TILE_SIZE, + input_modalities={"10_sentinel2_l2a_monthly": 12}, + segment_targets=["10_worldcover"], + example_idx=0, + ), + self.make_example( + height=TILE_SIZE, + width=TILE_SIZE, + input_modalities={"10_sentinel2_l2a_monthly": 12}, + segment_targets=["10_worldcover"], + example_idx=1, + ), + ] + inputs, targets, _ = collate_fn(batch) + # Make sure all examples in the batch have the same shape. + assert ( + inputs[0]["10_sentinel2_l2a_monthly"].shape + == inputs[1]["10_sentinel2_l2a_monthly"].shape + ) + assert ( + targets[0]["10_worldcover"]["classes"].shape + == targets[1]["10_worldcover"]["classes"].shape + ) + # Make sure that within an example the input and target have the same + # height/width. + assert ( + inputs[0]["10_sentinel2_l2a_monthly"].shape[1:3] + == targets[0]["10_worldcover"]["classes"].shape[0:2] + ) + # Make sure it is square and a multiple of the requested patch size. + height, width = inputs[0]["10_sentinel2_l2a_monthly"].shape[1:3] + assert height == width + assert width % patch_size == 0 + assert width >= min_size and width <= max_size + widths.add(width) + + # Make sure we got at least two unique widths.
+ assert len(widths) >= 2 + + def test_non_random(self) -> None: + """Verify that the same window name is always cropped the same way.""" + min_size = 16 + max_size = 48 + patch_size = 8 + collate_fn = CollateFunction( + randomize=False, + min_size=min_size, + max_size=max_size, + patch_size=patch_size, + ) + widths: list[set[int]] = [set() for _ in range(10)] + for _ in range(3): + for example_idx in range(10): + batch = [ + self.make_example( + height=TILE_SIZE, + width=TILE_SIZE, + input_modalities={"10_sentinel2_l2a_monthly": 12}, + segment_targets=["10_worldcover"], + example_idx=example_idx, + ) + ] + inputs, _, _ = collate_fn(batch) + width = inputs[0]["10_sentinel2_l2a_monthly"].shape[2] + assert width % patch_size == 0 + widths[example_idx].add(width) + + # Make sure each example has the same width across batches. + all_widths = set() + for width_set in widths: + assert len(width_set) == 1 + all_widths.update(width_set) + + # Make sure we got at least two unique widths. + assert len(all_widths) >= 2 + + def test_temporal_subset(self) -> None: + """Verify that we get different timesteps but they are always in order.""" + min_size = 8 + max_size = 8 + patch_size = 8 + collate_fn = CollateFunction( + randomize=True, + min_size=min_size, + max_size=max_size, + patch_size=patch_size, + ) + first_timesteps_set = set() + num_timesteps_set = set() + for _ in range(8): + batch = [ + self.make_example( + height=TILE_SIZE, + width=TILE_SIZE, + input_modalities={"10_sentinel2_l2a_monthly": 12}, + segment_targets=[], + example_idx=0, + timesteps=4, + ) + ] + inputs, _, _ = collate_fn(batch) + image = inputs[0]["10_sentinel2_l2a_monthly"] + num_timesteps = image.shape[0] // 12 + + # Verify the order of timesteps. Use .item() so the set below compares + # values rather than tensor identities. + selected_timesteps = [] + for timestep in range(num_timesteps): + selected_timesteps.append(image[timestep * 12, 0, 0].item()) + if len(selected_timesteps) >= 2: + assert selected_timesteps[-1] > selected_timesteps[-2] + + num_timesteps_set.add(num_timesteps) + first_timesteps_set.add(selected_timesteps[0]) + + # Make sure we got at least two different first timesteps. + assert len(first_timesteps_set) >= 2 diff --git a/tests/unit/swin_pretrain/test_model_crossattn.py b/tests/unit/swin_pretrain/test_model_crossattn.py new file mode 100644 index 00000000..9f3722ca --- /dev/null +++ b/tests/unit/swin_pretrain/test_model_crossattn.py @@ -0,0 +1,25 @@ +"""Unit tests for rslp.swin_pretrain.model_crossattn.""" + +import torch + +from rslp.swin_pretrain.model_crossattn import CrossAttentionTemporalPool + + +class TestCrossAttentionTemporalPool: + """Test CrossAttentionTemporalPool.""" + + def test_single_timestep(self) -> None: + """Verify the pooled output shape with a single timestep.""" + embed_dim = 64 + size = 8 + pool = CrossAttentionTemporalPool(embed_dim=embed_dim) + x = torch.zeros((1, 1, embed_dim, size, size)) + result = pool(x) + assert result.shape == (1, embed_dim, size, size) + + def test_multiple_timesteps(self) -> None: + """Verify the pooled output shape with multiple timesteps.""" + embed_dim = 64 + size = 8 + pool = CrossAttentionTemporalPool(embed_dim=embed_dim) + x = torch.zeros((1, 4, embed_dim, size, size)) + result = pool(x) + assert result.shape == (1, embed_dim, size, size)