From 82a71a39f98a2ed0c41e063732d2f1d189840c09 Mon Sep 17 00:00:00 2001 From: Kurt Schwehr Date: Wed, 31 Jul 2024 12:34:35 -0700 Subject: [PATCH] model_test.py: Refactor test data. PiperOrigin-RevId: 658104296 --- python/ee/tests/model_test.py | 217 +++++++++++----------------------- 1 file changed, 67 insertions(+), 150 deletions(-) diff --git a/python/ee/tests/model_test.py b/python/ee/tests/model_test.py index 03ee8825b..6125e443e 100644 --- a/python/ee/tests/model_test.py +++ b/python/ee/tests/model_test.py @@ -20,6 +20,59 @@ def make_expression_graph( } +def make_override_expression(key: str, pixel_type: str) -> Dict[str, Any]: + return { + 'dictionaryValue': { + 'values': { + key: { + 'functionInvocationValue': { + 'functionName': 'PixelType', + 'arguments': { + 'precision': { + 'functionInvocationValue': { + 'functionName': pixel_type, + 'arguments': {}, + } + } + }, + } + } + } + } + } + + +def make_type_expression( + key: str, pixel_type: str, dimensions: int +) -> Dict[str, Any]: + return { + 'dictionaryValue': { + 'values': { + key: { + 'dictionaryValue': { + 'values': { + 'dimensions': {'constantValue': dimensions}, + 'type': { + 'functionInvocationValue': { + 'functionName': 'PixelType', + 'arguments': { + 'precision': { + 'functionInvocationValue': { + 'functionName': pixel_type, + 'arguments': {}, + } + } + }, + } + }, + } + } + } + } + } + } + + class ModelTest(apitestcase.ApiTestCase): def test_serialize(self): @@ -88,83 +141,15 @@ def test_from_ai_platform_predictor(self): 'inputProperties': {'constantValue': input_properties}, 'inputShapes': {'constantValue': input_shapes}, 'inputTileSize': {'constantValue': input_tile_size}, - 'inputTypeOverride': { - 'dictionaryValue': { - 'values': { - 'c': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': 'PixelType.int8', - 'arguments': {}, - } - } - }, - } - } - } - } - }, + 'inputTypeOverride': make_override_expression( + 'c', 'PixelType.int8' + ), 'modelName': {'constantValue': model_name}, - 'outputBands': { - 'dictionaryValue': { - 'values': { - 'e': { - 'dictionaryValue': { - 'values': { - 'dimensions': {'constantValue': 10}, - 'type': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': ( - 'PixelType.int16' - ), - 'arguments': {}, - } - } - }, - } - }, - } - } - } - } - } - }, + 'outputBands': make_type_expression('e', 'PixelType.int16', 10), 'outputMultiplier': {'constantValue': output_multiplier}, - 'outputProperties': { - 'dictionaryValue': { - 'values': { - 'f': { - 'dictionaryValue': { - 'values': { - 'dimensions': {'constantValue': 11}, - 'type': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': ( - 'PixelType.int32' - ), - 'arguments': {}, - } - } - }, - } - }, - } - } - } - } - } - }, + 'outputProperties': make_type_expression( + 'f', 'PixelType.int32', 11 + ), 'outputTileSize': {'constantValue': output_tile_size}, 'proj': { 'functionInvocationValue': { @@ -249,83 +234,15 @@ def test_from_vertex_ai(self): 'inputProperties': {'constantValue': input_properties}, 'inputShapes': {'constantValue': input_shapes}, 'inputTileSize': {'constantValue': input_tile_size}, - 'inputTypeOverride': { - 'dictionaryValue': { - 'values': { - 'c': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': 'PixelType.int8', - 'arguments': {}, - } - } - }, - } - } - } - } - }, + 'inputTypeOverride': make_override_expression( + 'c', 'PixelType.int8' + ), 'maxPayloadBytes': {'constantValue': max_payload_bytes}, - 'outputBands': { - 'dictionaryValue': { - 'values': { - 'e': { - 'dictionaryValue': { - 'values': { - 'dimensions': {'constantValue': 10}, - 'type': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': ( - 'PixelType.int16' - ), - 'arguments': {}, - } - } - }, - } - }, - } - } - } - } - } - }, + 'outputBands': make_type_expression('e', 'PixelType.int16', 10), 'outputMultiplier': {'constantValue': output_multiplier}, - 'outputProperties': { - 'dictionaryValue': { - 'values': { - 'f': { - 'dictionaryValue': { - 'values': { - 'dimensions': {'constantValue': 11}, - 'type': { - 'functionInvocationValue': { - 'functionName': 'PixelType', - 'arguments': { - 'precision': { - 'functionInvocationValue': { - 'functionName': ( - 'PixelType.int32' - ), - 'arguments': {}, - } - } - }, - } - }, - } - } - } - } - } - }, + 'outputProperties': make_type_expression( + 'f', 'PixelType.int32', 11 + ), 'outputTileSize': {'constantValue': output_tile_size}, 'payloadFormat': {'constantValue': payload_format}, 'proj': {