11"""Evaluation adapter for TerraMind."""
22
3+ import os
4+
35import torch
46from rslearn .models .faster_rcnn import FasterRCNN
57from rslearn .models .multitask import MultiTaskModel
@@ -27,11 +29,21 @@ def get_model(
2729 task_timesteps : int = 1 ,
2830) -> torch .nn .Module :
2931 """Get appropriate TerraMind model."""
32+ model_id = os .environ ["EVAL_ADAPTER_MODEL_ID" ]
33+ if model_id == "terramind" :
34+ terramind_size = TerramindSize .BASE
35+ embedding_size = 768
36+ elif model_id == "terramind_large" :
37+ terramind_size = TerramindSize .LARGE
38+ embedding_size = 1024
39+ else :
40+ raise ValueError (f"unknown terramind model ID { model_id } " )
41+
3042 if task_type == "segment" :
3143 decoders = dict (
3244 eval_task = [
3345 UNetDecoder (
34- in_channels = [[16 , 768 ]],
46+ in_channels = [[16 , embedding_size ]],
3547 out_channels = task_channels ,
3648 conv_layers_per_resolution = 2 ,
3749 num_channels = {16 : 512 , 8 : 512 , 4 : 512 , 2 : 256 , 1 : 128 },
@@ -43,7 +55,7 @@ def get_model(
4355 decoders = dict (
4456 eval_task = [
4557 SegmentationPoolingDecoder (
46- in_channels = 768 ,
58+ in_channels = embedding_size ,
4759 out_channels = task_channels ,
4860 ),
4961 SegmentationHead (),
@@ -54,7 +66,7 @@ def get_model(
5466 eval_task = [
5567 FasterRCNN (
5668 downsample_factors = [16 ],
57- num_channels = 768 ,
69+ num_channels = embedding_size ,
5870 num_classes = task_channels ,
5971 anchor_sizes = [[32 ]],
6072 )
@@ -64,7 +76,7 @@ def get_model(
6476 decoders = dict (
6577 eval_task = [
6678 PoolingDecoder (
67- in_channels = 768 ,
79+ in_channels = embedding_size ,
6880 out_channels = task_channels ,
6981 num_conv_layers = 1 ,
7082 num_fc_layers = 1 ,
@@ -76,7 +88,7 @@ def get_model(
7688 decoders = dict (
7789 eval_task = [
7890 PoolingDecoder (
79- in_channels = 768 ,
91+ in_channels = embedding_size ,
8092 out_channels = task_channels ,
8193 num_conv_layers = 1 ,
8294 num_fc_layers = 1 ,
@@ -102,7 +114,7 @@ def get_model(
102114 SimpleTimeSeries (
103115 encoder = SimpleTimeSeries (
104116 encoder = Terramind (
105- model_size = TerramindSize . BASE ,
117+ model_size = terramind_size ,
106118 modalities = modalities ,
107119 ),
108120 image_keys = image_keys ,
@@ -115,7 +127,7 @@ def get_model(
115127 decoders = dict (
116128 eval_task = [
117129 PoolingDecoder (
118- in_channels = 768 * 2 ,
130+ in_channels = embedding_size * 2 ,
119131 out_channels = task_channels ,
120132 num_conv_layers = 1 ,
121133 num_fc_layers = 1 ,
@@ -129,7 +141,7 @@ def get_model(
129141 encoder = [
130142 SimpleTimeSeries (
131143 encoder = Terramind (
132- model_size = TerramindSize . BASE ,
144+ model_size = terramind_size ,
133145 modalities = modalities ,
134146 ),
135147 image_keys = image_keys ,
0 commit comments