diff --git a/.github/workflows/forest_loss_driver_prediction.yaml b/.github/workflows/forest_loss_driver_prediction.yaml index 7ecee887..87feb9b5 100644 --- a/.github/workflows/forest_loss_driver_prediction.yaml +++ b/.github/workflows/forest_loss_driver_prediction.yaml @@ -99,7 +99,7 @@ jobs: "bucket_name": "dfive-default", "mount_path": "/dfive-default" }] - RSLP_PREFIX: ${{ secrets.RSLP_PREFIX }} + RSLP_PREFIX: "/dfive-default/rslearn-eai/" TASK_ENV_VARS: | [{ "class_path": "beaker.BeakerEnvVar", diff --git a/data/forest_loss_driver/config.json b/data/forest_loss_driver/20240912/config.json similarity index 100% rename from data/forest_loss_driver/config.json rename to data/forest_loss_driver/20240912/config.json diff --git a/data/forest_loss_driver/config.yaml b/data/forest_loss_driver/20240912/config.yaml similarity index 100% rename from data/forest_loss_driver/config.yaml rename to data/forest_loss_driver/20240912/config.yaml diff --git a/data/forest_loss_driver/20251104/config.json b/data/forest_loss_driver/20251104/config.json new file mode 100644 index 00000000..9c08b973 --- /dev/null +++ b/data/forest_loss_driver/20251104/config.json @@ -0,0 +1,225 @@ +{ + "layers": { + "best_post_0": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "best_post_1": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "best_post_2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "best_pre_0": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "best_pre_1": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "best_pre_2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "label": { + "class_names": [ + "agriculture", + "agriculture-generic", + "agriculture-mennonite", + "agriculture-rice", + "agriculture-small", + "airstrip", + "burned", + "coca", + "flood", + "human", + "hurricane", + "landslide", + "logging", + "mining", + "natural", + "none", + "river", + "road", + "unknown", + "unlabeled" + ], + "class_property_name": "new_label", + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "mask_vector": { + "type": "vector" + }, + "output": { + "type": "vector" + }, + "post_sentinel2": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "duration": "180d", + "harmonize": true, + "ingest": false, + "name": "rslearn.data_sources.planetary_computer.Sentinel2", + "query_config": { + "max_matches": 4, + "space_mode": "CONTAINS" + }, + "sort_by": "eo:cloud_cover", + "time_offset": "7d" + }, + "type": 
"raster" + }, + "pre_sentinel2": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "cache_dir": "cache/planetary_computer", + "duration": "180d", + "harmonize": true, + "ingest": false, + "name": "rslearn.data_sources.planetary_computer.Sentinel2", + "query_config": { + "max_matches": 4, + "space_mode": "CONTAINS" + }, + "sort_by": "eo:cloud_cover", + "time_offset": "-300d" + }, + "type": "raster" + } + } +} diff --git a/data/forest_loss_driver/20251104/config.yaml b/data/forest_loss_driver/20251104/config.yaml new file mode 100644 index 00000000..5bfb2d5f --- /dev/null +++ b/data/forest_loss_driver/20251104/config.yaml @@ -0,0 +1,122 @@ +model: +# class_path: olmoearth_projects.train.classification_confusion_matrix.CMLightningModule + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + # We use SimpleTimeSeries not for time series but just to concatenate the + # feature maps for pre-forest-loss and post-forest-loss. + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_BASE + patch_size: 4 + image_channels: 48 + image_key: "sentinel2_l2a" + groups: [[0], [1]] + decoder: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 1536 + out_channels: 10 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 + optimizer: + class_path: rslearn.train.optimizer.AdamW + init_args: + lr: 0.0001 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/combined/ + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["pre_sentinel2", "pre_sentinel2.1", "pre_sentinel2.2", "pre_sentinel2.3", "post_sentinel2", "post_sentinel2.1", "post_sentinel2.2", "post_sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + load_all_layers: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "new_label" + classes: ["agriculture", "mining", "airstrip", "road", "logging", "burned", "landslide", "hurricane", "river", "none"] + allow_invalid: true + metric_kwargs: + average: "micro" + prob_property: "probs" + skip_unknown_categories: true + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + mode: "center" + size: 64 + image_selectors: + - sentinel2_l2a + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: + - sentinel2_l2a + # To support visualization via image key. 
+ - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_l2a: [2, 1, 0] + output_selector: image + train_config: + tags: + split: train + val_config: + tags: + split: val + test_config: + tags: + split: val + predict_config: + skip_targets: true +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_accuracy + mode: max + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: output +rslp_project: olmoearth_forest_loss_driver +rslp_experiment: forest_loss_driver_00 diff --git a/data/forest_loss_driver/README.md b/data/forest_loss_driver/README.md index ff3f3d50..c649d00a 100644 --- a/data/forest_loss_driver/README.md +++ b/data/forest_loss_driver/README.md @@ -18,7 +18,6 @@ Dataset Versions Dataset Configurations ---------------------- -- config.json: current inference config that uses RGB PNGs. - config_ms.json: corresponds to dataset 20250429, it gets L2A images with all bands stored as GeoTIFF. - config_studio_annotation.json: this is original config used for Brazil+Colombia @@ -26,3 +25,12 @@ Dataset Configurations Studio. It also gets Planet Labs RGB images. - config_multimodal.json: this gets inputs that match what Helios can do, Sentinel-2 + Sentinel-1 + Landsat. + + +Deployment Details +------------------ + +- 20251104: deploys OlmoEarth-v1-FT-ForestLossDriver-Base in Brazil, Peru, and Colombia. + The model uses Sentinel-2 L2A images from Microsoft Planetary Computer. +- 20240912: the original deployment, trained on Peru only; it applies Satlas to Sentinel-2 L1C + RGB PNGs. diff --git a/data/forest_loss_driver/peru_brazil_colombia_model/README.md b/data/forest_loss_driver/peru_brazil_colombia_model/README.md new file mode 100644 index 00000000..1c8fc1cd --- /dev/null +++ b/data/forest_loss_driver/peru_brazil_colombia_model/README.md @@ -0,0 +1,2 @@ +I tried training SatlasPretrain and OlmoEarth-SwinPretrain on the forest loss driver task, but they +did not perform very well, so they were not used.
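For context on the 20251104 deployment described above: in data/forest_loss_driver/20251104/config.json, the pre_sentinel2 and post_sentinel2 layers both use duration "180d", with time_offset "-300d" and "7d" respectively, and max_matches: 4. The sketch below is illustrative only; it assumes the offset shifts the query window start relative to the forest loss alert date and that the duration extends forward from that shifted start. The actual behavior is implemented by rslearn's Planetary Computer data source, not by this code.

```python
from datetime import datetime, timedelta

# Illustrative sketch only (see note above): approximate Sentinel-2 query windows
# for the 20251104 config, assuming time_offset is applied to the alert date and
# duration extends forward from the shifted start.
def sentinel2_query_windows(alert_date: datetime) -> dict[str, tuple[datetime, datetime]]:
    duration = timedelta(days=180)                 # duration: "180d" on both layers
    pre_start = alert_date - timedelta(days=300)   # pre_sentinel2 time_offset: "-300d"
    post_start = alert_date + timedelta(days=7)    # post_sentinel2 time_offset: "7d"
    return {
        "pre_sentinel2": (pre_start, pre_start + duration),
        "post_sentinel2": (post_start, post_start + duration),
    }

# Example: print the windows for an alert confirmed on 2025-06-01.
print(sentinel2_query_windows(datetime(2025, 6, 1)))
```

With max_matches: 4, up to four image layers are materialized per window, which lines up with the pre_sentinel2 / pre_sentinel2.1-.3 and post_sentinel2 / post_sentinel2.1-.3 layers read by the 20251104 config.yaml.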
diff --git a/data/forest_loss_driver/peru_brazil_colombia_model/model_satlas.yaml b/data/forest_loss_driver/peru_brazil_colombia_model/model_satlas.yaml new file mode 100644 index 00000000..ae2ff3d0 --- /dev/null +++ b/data/forest_loss_driver/peru_brazil_colombia_model/model_satlas.yaml @@ -0,0 +1,123 @@ +model: +# class_path: olmoearth_projects.train.classification_confusion_matrix.CMLightningModule + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + groups: [[0, 1, 2, 3], [4, 5, 6, 7]] + decoder: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 2048 + out_channels: 10 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 + optimizer: + class_path: rslearn.train.optimizer.AdamW + init_args: + lr: 0.0001 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/combined/ + inputs: + image: + data_type: "raster" + layers: ["pre_sentinel2", "pre_sentinel2.1", "pre_sentinel2.2", "pre_sentinel2.3", "post_sentinel2", "post_sentinel2.1", "post_sentinel2.2", "post_sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + load_all_layers: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "new_label" + classes: ["agriculture", "mining", "airstrip", "road", "logging", "burned", "landslide", "hurricane", "river", "none"] + allow_invalid: true + metric_kwargs: + average: "micro" + prob_property: "probs" + skip_unknown_categories: true + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + mode: "center" + size: 64 + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + num_bands: 9 + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + num_bands: 9 + - class_path: rslearn.train.transforms.flip.Flip + train_config: + tags: + split: train + val_config: + tags: + split: val + test_config: + tags: + split: val + predict_config: + skip_targets: true +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" +# - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze +# init_args: +# module_selector: ["model", "encoder", 0, "encoder", "model"] +# unfreeze_at_epoch: 20 +# unfreeze_lr_factor: 10 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: 
val_accuracy + mode: max + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: output +rslp_project: 20251103_forest_loss_driver +rslp_experiment: satlas_00 diff --git a/data/forest_loss_driver/peru_brazil_colombia_model/model_swinp.yaml b/data/forest_loss_driver/peru_brazil_colombia_model/model_swinp.yaml new file mode 100644 index 00000000..e38be98a --- /dev/null +++ b/data/forest_loss_driver/peru_brazil_colombia_model/model_swinp.yaml @@ -0,0 +1,113 @@ +model: +# class_path: olmoearth_projects.train.classification_confusion_matrix.CMLightningModule + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslp.swin_pretrain.model.Model + init_args: + target_resolution_factor: null + image_channels: 48 + backbone_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]] + groups: [[0], [1]] + decoder: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 2048 + out_channels: 10 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 + optimizer: + class_path: rslearn.train.optimizer.AdamW + init_args: + lr: 0.0001 + restore_config: + restore_path: /weka/dfive-default/rslearn-eai/projects/swin_pretrain/pretrain_4mod_1conv/checkpoints/epoch130.ckpt.bak + selector: ["state_dict"] + remap_prefixes: + - ["model.encoder.0.", "encoder.0.encoder."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/combined/ + inputs: + image: + data_type: "raster" + layers: ["pre_sentinel2", "pre_sentinel2.1", "pre_sentinel2.2", "pre_sentinel2.3", "post_sentinel2", "post_sentinel2.1", "post_sentinel2.2", "post_sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + load_all_layers: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "new_label" + classes: ["agriculture", "mining", "airstrip", "road", "logging", "burned", "landslide", "hurricane", "river", "none"] + allow_invalid: true + metric_kwargs: + average: "micro" + prob_property: "probs" + skip_unknown_categories: true + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + mode: "center" + size: 64 + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + - class_path: rslearn.train.transforms.flip.Flip + train_config: + tags: + split: train + val_config: + tags: + split: val + test_config: + tags: + split: val + predict_config: + skip_targets: true +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" +# - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze +# init_args: +# module_selector: ["model", "encoder", 0, "encoder", "model"] +# unfreeze_at_epoch: 20 +# unfreeze_lr_factor: 10 + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + 
init_args: + save_top_k: 1 + save_last: true + monitor: val_accuracy + mode: max + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: output +rslp_project: 20251103_forest_loss_driver +rslp_experiment: swinp_03 diff --git a/rslp/forest_loss_driver/config/forest_loss_driver_predict_pipeline_config.yaml b/rslp/forest_loss_driver/config/forest_loss_driver_predict_pipeline_config.yaml index cbea4805..1259dc14 100644 --- a/rslp/forest_loss_driver/config/forest_loss_driver_predict_pipeline_config.yaml +++ b/rslp/forest_loss_driver/config/forest_loss_driver_predict_pipeline_config.yaml @@ -2,10 +2,20 @@ extract_alerts_args: index_cache_dir: "file:///tmp/index_cache_dir/" tile_store_dir: "file:///dfive-default/rslearn-eai/datasets/forest_loss_driver/tile_store_root_dir/" workers: 112 - countries: ["PE"] + countries: ["PE", "BR", "CO"] gcs_tiff_filenames: + - "040W_10S_030W_00N.tif" + - "040W_20S_030W_10S.tif" + - "050W_00N_040W_10N.tif" + - "050W_10S_040W_00N.tif" + - "050W_20S_040W_10S.tif" + - "060W_00N_050W_10N.tif" + - "060W_10S_050W_00N.tif" + - "060W_20S_050W_10S.tif" + - "070W_00N_060W_10N.tif" - "070W_10S_060W_00N.tif" - "070W_20S_060W_10S.tif" + - "080W_00N_070W_10N.tif" - "080W_10S_070W_00N.tif" - "080W_20S_070W_10S.tif" select_least_cloudy_images_args: diff --git a/rslp/forest_loss_driver/extract_dataset/__init__.py b/rslp/forest_loss_driver/extract_dataset/__init__.py index 891adb22..c9c46a26 100644 --- a/rslp/forest_loss_driver/extract_dataset/__init__.py +++ b/rslp/forest_loss_driver/extract_dataset/__init__.py @@ -2,6 +2,7 @@ import multiprocessing from dataclasses import dataclass, field +from datetime import timedelta from upath import UPath @@ -103,7 +104,8 @@ class VisLayerMaterializeArgs(MaterializePipelineArgs): apply_windows_args=ApplyWindowsArgs( use_initial_job=True, workers=DEFAULT_VIS_LAYER_WORKERS ), - retry_max_attempts=5, + retry_max_attempts=20, + retry_backoff=timedelta(seconds=30), ), ) ingest_args: IngestArgs = field( @@ -116,6 +118,8 @@ class VisLayerMaterializeArgs(MaterializePipelineArgs): default_factory=lambda: MaterializeArgs( ignore_errors=True, apply_windows_args=ApplyWindowsArgs(workers=DEFAULT_VIS_LAYER_WORKERS), + retry_max_attempts=5, + retry_backoff=timedelta(seconds=30), ), ) diff --git a/rslp/forest_loss_driver/extract_dataset/extract_alerts.py b/rslp/forest_loss_driver/extract_dataset/extract_alerts.py index a5ec0a29..770da000 100644 --- a/rslp/forest_loss_driver/extract_dataset/extract_alerts.py +++ b/rslp/forest_loss_driver/extract_dataset/extract_alerts.py @@ -46,7 +46,7 @@ WEB_MERCATOR_PROJECTION = Projection(WEB_MERCATOR_CRS, PIXEL_SIZE, -PIXEL_SIZE) ANNOTATION_WEBSITE_MERCATOR_OFFSET = 512 * (2**12) -INFERENCE_DATASET_CONFIG = "data/forest_loss_driver/config.json" +INFERENCE_DATASET_CONFIG = "data/forest_loss_driver/20251104/config.json" # Filename used to indicate that alert extraction is done for a given dataset. 
COMPLETED_FNAME = "extract_alerts_completed" diff --git a/rslp/forest_loss_driver/extract_dataset/least_cloudy_image_selector.py b/rslp/forest_loss_driver/extract_dataset/least_cloudy_image_selector.py index 14806f8e..50886674 100644 --- a/rslp/forest_loss_driver/extract_dataset/least_cloudy_image_selector.py +++ b/rslp/forest_loss_driver/extract_dataset/least_cloudy_image_selector.py @@ -11,11 +11,10 @@ from rslearn.config import RasterFormatConfig, RasterLayerConfig from rslearn.data_sources import Item from rslearn.dataset import Dataset, Window, WindowLayerData -from rslearn.utils.raster_format import load_raster_format +from rslearn.utils.raster_format import SingleImageRasterFormat, load_raster_format from upath import UPath from rslp.log_utils import get_logger -from rslp.utils.fs import copy_files logger = get_logger(__name__) @@ -87,7 +86,8 @@ def select_least_cloudy_images( layer_times[(layer_name, group_idx)] = item.geometry.time_range[0] # Find least cloudy pre and post images. - layer_cloudiness: dict[tuple[str, int], int] = {} + # layer_cloudiness maps from (layer_name, group_idx) => (cloudiness, RGB image). + layer_cloudiness: dict[tuple[str, int], tuple[int, np.ndarray]] = {} for layer_name, group_idx in layer_times.keys(): if not window.is_layer_completed(layer_name, group_idx): continue @@ -106,26 +106,19 @@ def select_least_cloudy_images( raster_dir, window.projection, window.bounds ) + # Get RGB by selecting (B04, B03, B02) and dividing by 10. + rgb_indices = ( + bands.index("B04"), + bands.index("B03"), + bands.index("B02"), + ) + rgb_array = array[rgb_indices, :, :] // 10 + # Use the center crop since that's the most important part. - array = array[:, 32:96, 32:96] - - # Handle differently depending on if we have TCI (RGB) data or the individual - # bands. - if "R" in bands: - assert array.shape[0] == 3 - cloudiness = compute_cloudiness_score(array) - - else: - # Get RGB by selecting (B04, B03, B02) and dividing by 10. - rgb_indices = ( - bands.index("B04"), - bands.index("B03"), - bands.index("B02"), - ) - rgb_array = array[rgb_indices, :, :] // 10 - cloudiness = compute_cloudiness_score(rgb_array) + center_crop = rgb_array[:, 32:96, 32:96] + cloudiness = compute_cloudiness_score(center_crop) - layer_cloudiness[(layer_name, group_idx)] = cloudiness + layer_cloudiness[(layer_name, group_idx)] = (cloudiness, rgb_array) # Determine the least cloudy pre and post images. # We copy those images to a new "best_X" layer. @@ -137,21 +130,32 @@ def select_least_cloudy_images( least_cloudy_times = {} for pre_or_post in ["pre", "post"]: image_list = [ - (layer_name, group_idx, cloudiness) - for (layer_name, group_idx), cloudiness in layer_cloudiness.items() + (layer_name, group_idx, cloudiness, rgb_array) + for (layer_name, group_idx), ( + cloudiness, + rgb_array, + ) in layer_cloudiness.items() if layer_name.startswith(pre_or_post) ] if len(image_list) < min_choices: return # Sort by cloudiness (third element of tuple) so we can pick the least cloudy. image_list.sort(key=lambda t: t[2]) - for idx, (layer_name, group_idx, _) in enumerate(image_list[0:num_outs]): + for idx, (layer_name, group_idx, _, rgb_array) in enumerate( + image_list[0:num_outs] + ): # The layer name for the best images is e.g. "best_pre_0". dst_layer_name = f"best_{pre_or_post}_{idx}" - src_layer_dir = window.get_layer_dir(layer_name, group_idx) - dst_layer_dir = window.get_layer_dir(dst_layer_name) - copy_files(src_layer_dir, dst_layer_dir) + # Write the RGB version of the data. 
+ dst_raster_dir = window.get_raster_dir(dst_layer_name, ["R", "G", "B"]) + SingleImageRasterFormat().encode_raster( + dst_raster_dir, + window.projection, + window.bounds, + np.clip(rgb_array, 0, 255).astype(np.uint8), + ) + window.mark_layer_completed(dst_layer_name) layer_time = layer_times[(layer_name, group_idx)] least_cloudy_times[dst_layer_name] = layer_time.isoformat() diff --git a/rslp/forest_loss_driver/predict_pipeline.py b/rslp/forest_loss_driver/predict_pipeline.py index e875a252..f0dc3125 100644 --- a/rslp/forest_loss_driver/predict_pipeline.py +++ b/rslp/forest_loss_driver/predict_pipeline.py @@ -5,7 +5,7 @@ from rslp.log_utils import get_logger from rslp.utils.rslearn import run_model_predict -MODEL_CFG_FNAME = "data/forest_loss_driver/config.yaml" +MODEL_CFG_FNAME = "data/forest_loss_driver/20251104/config.yaml" logger = get_logger(__name__)
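The least_cloudy_image_selector.py changes above now derive an RGB preview from B04/B03/B02 (scaled with // 10), score a 64x64 center crop with compute_cloudiness_score, keep the lowest-scoring pre and post images, and write them directly as uint8 PNG best_* layers rather than copying the source rasters. Below is a minimal sketch of that selection flow; the brightness-based score is only a placeholder for the real compute_cloudiness_score, and the layer names and array shapes are hypothetical.

```python
import numpy as np

def placeholder_cloudiness(rgb: np.ndarray) -> int:
    # Rough stand-in for compute_cloudiness_score: count pixels where all three
    # channels are bright, as a crude proxy for cloud cover.
    return int(np.count_nonzero((rgb > 140).all(axis=0)))

def pick_least_cloudy(rgb_by_layer: dict[str, np.ndarray], num_outs: int = 3) -> list[str]:
    # Score the 64x64 center crop of each (3, 128, 128) RGB array and return the
    # names of the num_outs least cloudy layers (cf. best_pre_0..2 / best_post_0..2).
    scored = sorted(
        (placeholder_cloudiness(rgb[:, 32:96, 32:96]), name)
        for name, rgb in rgb_by_layer.items()
    )
    return [name for _, name in scored[:num_outs]]

# Hypothetical usage with random arrays standing in for materialized Sentinel-2 layers.
rng = np.random.default_rng(0)
layers = {f"pre_sentinel2.{i}": rng.integers(0, 255, (3, 128, 128), dtype=np.uint8) for i in range(4)}
print(pick_least_cloudy(layers))
```

This mirrors the intent of the new code path: cloud scoring and the best_* PNG outputs now both come from the same in-memory RGB array, so the full multiband source rasters no longer need to be copied into the best_* layers.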