
Commit 3c34d6a

Merge pull request #257 from allenai/favyen/20251105-olmoearth-evals
Add CROMA/Terramind/OlmoEarth Large eval + randomly initialized OlmoEarth Base
2 parents 74d48a3 + 147166b commit 3c34d6a

8 files changed: +76, -20 lines changed

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+trainer:
+  callbacks+:
+    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
+      init_args:
+        module_selector: ["model", "model", "encoder", 0, "encoder"]
+        unfreeze_at_epoch: 20
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+trainer:
+  callbacks+:
+    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
+      init_args:
+        module_selector: ["model", "model", "encoder", 0]
+        unfreeze_at_epoch: 20
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+trainer:
+  callbacks+:
+    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
+      init_args:
+        module_selector: ["model", "model", "encoder", 0]
+        unfreeze_at_epoch: 20
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+trainer:
+  callbacks+:
+    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
+      init_args:
+        module_selector: ["model", "model", "encoder", 0, "encoder"]
+        unfreeze_at_epoch: 20
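The four configs above attach rslearn's FreezeUnfreeze callback so the pretrained encoder stays frozen for the first 20 epochs of fine-tuning. As a rough illustration of what a module_selector path such as ["model", "model", "encoder", 0, "encoder"] points at, the sketch below walks a mixed attribute/index path down to a submodule; this is a generic illustration of the pattern, not rslearn's actual implementation.

# Illustration only (not rslearn code): resolving a selector path made of
# attribute names (str) and sequence indices (int) down to a target submodule.
from typing import Any, Sequence


def resolve_selector(root: Any, selector: Sequence) -> Any:
    """Follow a mixed attribute/index path, e.g. ["model", "model", "encoder", 0, "encoder"]."""
    target = root
    for key in selector:
        # Integers index into containers such as nn.ModuleList; strings are attribute names.
        target = target[key] if isinstance(key, int) else getattr(target, key)
    return target


# Conceptually, FreezeUnfreeze freezes the resolved submodule's parameters
# until unfreeze_at_epoch (20 here), then re-enables gradients.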

rslp/olmoearth_evals/croma.py

Lines changed: 20 additions & 8 deletions
@@ -1,5 +1,7 @@
 """Evaluation adapter for CROMA."""
 
+import os
+
 import torch
 from rslearn.models.croma import Croma, CromaModality, CromaNormalize, CromaSize
 from rslearn.models.faster_rcnn import FasterRCNN
@@ -27,11 +29,21 @@ def get_model(
     task_timesteps: int = 1,
 ) -> torch.nn.Module:
     """Get appropriate CROMA model."""
+    model_id = os.environ["EVAL_ADAPTER_MODEL_ID"]
+    if model_id == "croma":
+        croma_size = CromaSize.BASE
+        embedding_size = 768
+    elif model_id == "croma_large":
+        croma_size = CromaSize.LARGE
+        embedding_size = 1024
+    else:
+        raise ValueError(f"unknown croma model ID {model_id}")
+
     if task_type == "segment":
         decoders = dict(
             eval_task=[
                 UNetDecoder(
-                    in_channels=[[8, 768]],
+                    in_channels=[[8, embedding_size]],
                     out_channels=task_channels,
                     conv_layers_per_resolution=2,
                     num_channels={8: 512, 4: 512, 2: 256, 1: 128},
@@ -43,7 +55,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 SegmentationPoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                 ),
                 SegmentationHead(),
@@ -54,7 +66,7 @@ def get_model(
             eval_task=[
                 FasterRCNN(
                     downsample_factors=[8],
-                    num_channels=768,
+                    num_channels=embedding_size,
                     num_classes=task_channels,
                     anchor_sizes=[[32]],
                 )
@@ -64,7 +76,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 PoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                     num_conv_layers=1,
                     num_fc_layers=1,
@@ -76,7 +88,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 PoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                     num_conv_layers=1,
                     num_fc_layers=1,
@@ -108,7 +120,7 @@ def get_model(
             SimpleTimeSeries(
                 encoder=SimpleTimeSeries(
                     encoder=Croma(
-                        size=CromaSize.BASE,
+                        size=croma_size,
                         modality=modality,
                         image_resolution=input_size,
                     ),
@@ -122,7 +134,7 @@ def get_model(
             decoders=dict(
                 eval_task=[
                     PoolingDecoder(
-                        in_channels=768 * 2,
+                        in_channels=embedding_size * 2,
                         out_channels=task_channels,
                         num_conv_layers=1,
                         num_fc_layers=1,
@@ -136,7 +148,7 @@ def get_model(
         encoder=[
             SimpleTimeSeries(
                 encoder=Croma(
-                    size=CromaSize.BASE,
+                    size=croma_size,
                     modality=modality,
                     image_resolution=input_size,
                 ),
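The croma.py change reads EVAL_ADAPTER_MODEL_ID and maps it to both an encoder size and the matching embedding width, so every decoder head picks up 768 for Base and 1024 for Large. Below is the same mapping written as a lookup table, purely to make the pairing explicit; this is a sketch, and the committed code uses the if/elif chain shown above.

# Equivalent table-driven view of the selection in croma.py (sketch only).
import os

from rslearn.models.croma import CromaSize

CROMA_VARIANTS = {
    "croma": (CromaSize.BASE, 768),          # Base: 768-dim embeddings
    "croma_large": (CromaSize.LARGE, 1024),  # Large: 1024-dim embeddings
}

model_id = os.environ["EVAL_ADAPTER_MODEL_ID"]
if model_id not in CROMA_VARIANTS:
    raise ValueError(f"unknown croma model ID {model_id}")
croma_size, embedding_size = CROMA_VARIANTS[model_id]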

rslp/olmoearth_evals/eval_adapter.py

Lines changed: 5 additions & 1 deletion
@@ -25,16 +25,20 @@
     "clay": clay,
     "copernicusfm": copernicusfm,
     "croma": croma,
+    "croma_large": croma,
     "dinov3": dinov3,
     "galileo": galileo,
     "olmoearth": olmoearth,
-    "olmoearth_tiny": olmoearth,
     "olmoearth_nano": olmoearth,
+    "olmoearth_tiny": olmoearth,
+    "olmoearth_large": olmoearth,
+    "olmoearth_random": olmoearth,
     "panopticon": panopticon,
     "presto": presto,
     "prithvi": prithvi,
     "satlaspretrain": satlaspretrain,
     "terramind": terramind,
+    "terramind_large": terramind,
     "aef": aef,
 }
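The registry in eval_adapter.py maps several model IDs to one adapter module: the new croma_large, olmoearth_large, olmoearth_random, and terramind_large entries reuse the existing croma, olmoearth, and terramind adapters, which then re-read the same environment variable to pick the concrete size or initialization. A hedged sketch of that dispatch step follows; the registry name and the surrounding call are assumptions, while the ID-to-module pairs come from the diff.

# Hypothetical dispatch sketch; only the ID -> adapter-module pairs are from the diff.
import os

from rslp.olmoearth_evals import croma, olmoearth, terramind  # subset of the real registry

ADAPTERS = {
    "croma": croma,
    "croma_large": croma,            # same adapter, CromaSize.LARGE checkpoint
    "olmoearth": olmoearth,
    "olmoearth_large": olmoearth,
    "olmoearth_random": olmoearth,   # Base architecture, random initialization
    "terramind": terramind,
    "terramind_large": terramind,    # same adapter, TerramindSize.LARGE checkpoint
}

model_id = os.environ["EVAL_ADAPTER_MODEL_ID"]
adapter = ADAPTERS[model_id]
# adapter.get_model(...) then inspects EVAL_ADAPTER_MODEL_ID again to choose
# encoder size, embedding width, and (for olmoearth_random) random weights.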

rslp/olmoearth_evals/olmoearth.py

Lines changed: 7 additions & 3 deletions
@@ -32,12 +32,14 @@ def get_model(
 ) -> torch.nn.Module:
     """Get appropriate OlmoEarth model."""
     model_id = os.environ["EVAL_ADAPTER_MODEL_ID"]
-    if model_id == "olmoearth":
+    if model_id in ["olmoearth", "olmoearth_random"]:
         olmoearth_model_id = ModelID.OLMOEARTH_V1_BASE
-    elif model_id == "olmoearth_tiny":
-        olmoearth_model_id = ModelID.OLMOEARTH_V1_TINY
     elif model_id == "olmoearth_nano":
         olmoearth_model_id = ModelID.OLMOEARTH_V1_NANO
+    elif model_id == "olmoearth_tiny":
+        olmoearth_model_id = ModelID.OLMOEARTH_V1_TINY
+    elif model_id == "olmoearth_large":
+        olmoearth_model_id = ModelID.OLMOEARTH_V1_LARGE
     else:
         raise ValueError(f"unknown olmoearth model ID {model_id}")
 
@@ -124,6 +126,7 @@ def get_model(
             encoder=OlmoEarth(
                 model_id=olmoearth_model_id,
                 patch_size=4,
+                random_initialization=model_id == "olmoearth_random",
             ),
             image_channels=12 * 4,
             image_key="sentinel2_l2a",
@@ -148,6 +151,7 @@ def get_model(
             OlmoEarth(
                 model_id=olmoearth_model_id,
                 patch_size=4,
+                random_initialization=model_id == "olmoearth_random",
             ),
         ],
         decoders=decoders,
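olmoearth_random keeps the OLMOEARTH_V1_BASE architecture but constructs the encoder with random_initialization=True, giving a no-pretraining control. Below is a minimal sketch of how one might verify that this baseline really differs from the pretrained weights; the OlmoEarth/ModelID import path is an assumption, while the constructor arguments mirror the diff above.

# Sanity-check sketch; the import path below is assumed, not taken from the diff.
import torch
from rslearn.models.olmoearth import ModelID, OlmoEarth  # assumed module path

pretrained = OlmoEarth(model_id=ModelID.OLMOEARTH_V1_BASE, patch_size=4)
random_init = OlmoEarth(
    model_id=ModelID.OLMOEARTH_V1_BASE,
    patch_size=4,
    random_initialization=True,
)

# With random initialization the encoder weights should not match the
# pretrained checkpoint, so this is expected to print False.
identical = all(
    torch.equal(a, b)
    for a, b in zip(
        pretrained.state_dict().values(), random_init.state_dict().values()
    )
)
print("weights identical:", identical)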

rslp/olmoearth_evals/terramind.py

Lines changed: 20 additions & 8 deletions
@@ -1,5 +1,7 @@
 """Evaluation adapter for TerraMind."""
 
+import os
+
 import torch
 from rslearn.models.faster_rcnn import FasterRCNN
 from rslearn.models.multitask import MultiTaskModel
@@ -27,11 +29,21 @@ def get_model(
     task_timesteps: int = 1,
 ) -> torch.nn.Module:
     """Get appropriate TerraMind model."""
+    model_id = os.environ["EVAL_ADAPTER_MODEL_ID"]
+    if model_id == "terramind":
+        terramind_size = TerramindSize.BASE
+        embedding_size = 768
+    elif model_id == "terramind_large":
+        terramind_size = TerramindSize.LARGE
+        embedding_size = 1024
+    else:
+        raise ValueError(f"unknown terramind model ID {model_id}")
+
     if task_type == "segment":
         decoders = dict(
             eval_task=[
                 UNetDecoder(
-                    in_channels=[[16, 768]],
+                    in_channels=[[16, embedding_size]],
                     out_channels=task_channels,
                     conv_layers_per_resolution=2,
                     num_channels={16: 512, 8: 512, 4: 512, 2: 256, 1: 128},
@@ -43,7 +55,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 SegmentationPoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                 ),
                 SegmentationHead(),
@@ -54,7 +66,7 @@ def get_model(
             eval_task=[
                 FasterRCNN(
                     downsample_factors=[16],
-                    num_channels=768,
+                    num_channels=embedding_size,
                     num_classes=task_channels,
                     anchor_sizes=[[32]],
                 )
@@ -64,7 +76,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 PoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                     num_conv_layers=1,
                     num_fc_layers=1,
@@ -76,7 +88,7 @@ def get_model(
         decoders = dict(
             eval_task=[
                 PoolingDecoder(
-                    in_channels=768,
+                    in_channels=embedding_size,
                     out_channels=task_channels,
                     num_conv_layers=1,
                     num_fc_layers=1,
@@ -102,7 +114,7 @@ def get_model(
             SimpleTimeSeries(
                 encoder=SimpleTimeSeries(
                     encoder=Terramind(
-                        model_size=TerramindSize.BASE,
+                        model_size=terramind_size,
                         modalities=modalities,
                     ),
                     image_keys=image_keys,
@@ -115,7 +127,7 @@ def get_model(
             decoders=dict(
                 eval_task=[
                     PoolingDecoder(
-                        in_channels=768 * 2,
+                        in_channels=embedding_size * 2,
                         out_channels=task_channels,
                         num_conv_layers=1,
                         num_fc_layers=1,
@@ -129,7 +141,7 @@ def get_model(
         encoder=[
             SimpleTimeSeries(
                 encoder=Terramind(
-                    model_size=TerramindSize.BASE,
+                    model_size=terramind_size,
                     modalities=modalities,
                 ),
                 image_keys=image_keys,
