Commit a6e1692
fix tests
Sameer Sharma committed Aug 5, 2024
1 parent 2221de1 commit a6e1692
Showing 2 changed files with 56 additions and 46 deletions.
26 changes: 14 additions & 12 deletions service_configuration_lib/utils.py
@@ -11,12 +11,12 @@
 from socket import SO_REUSEADDR
 from socket import socket
 from socket import SOL_SOCKET
-from typing import Dict
 from typing import Mapping
 from typing import Tuple
+from typing import Literal
+from typing import Dict

 import yaml
-from typing_extensions import Literal

 DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
 POD_TEMPLATE_PATH = '/nail/tmp/spark-pt-{file_uuid}.yaml'
@@ -26,7 +26,7 @@
 EPHEMERAL_PORT_START = 49152
 EPHEMERAL_PORT_END = 65535

-MEM_MULTIPLIER = {"k": 1024, "m": 1024**2, "g": 1024**3, "t": 1024**4}
+MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}

 SPARK_DRIVER_MEM_DEFAULT_MB = 2048
 SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT = 0.1
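
For reference, a standalone sketch (not part of the commit) of the arithmetic this multiplier table drives, mirroring the parsing in get_spark_memory_in_unit below; the '1.5g' input is a hypothetical example:

    # JVM-style memory string -> bytes, using the same multiplier table.
    MEM_MULTIPLIER = {'k': 1024, 'm': 1024**2, 'g': 1024**3, 't': 1024**4}

    mem = '1.5g'  # hypothetical input
    memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
    print(memory_bytes / MEM_MULTIPLIER['m'])  # 1536.0 (MB)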
@@ -157,7 +157,7 @@ def get_runtime_env() -> str:
     return 'unknown'


-def get_spark_memory_in_unit(mem: str, unit: Literal["k", "m", "g", "t"]) -> float:
+def get_spark_memory_in_unit(mem: str, unit: Literal['k', 'm', 'g', 't']) -> float:
     """
     Converts Spark memory to the desired unit.
     mem is the same format as JVM memory strings: just number or number followed by 'k', 'm', 'g' or 't'.
@@ -170,20 +170,20 @@ def get_spark_memory_in_unit(mem: str, unit: Literal["k", "m", "g", "t"]) -> float:
     try:
         memory_bytes = float(mem[:-1]) * MEM_MULTIPLIER[mem[-1]]
     except (ValueError, IndexError):
-        print(f"Unable to parse memory value {mem}.")
+        print(f'Unable to parse memory value {mem}.')
         raise
     memory_unit = memory_bytes / MEM_MULTIPLIER[unit]
-    return memory_unit
+    return round(memory_unit, 5)


 def get_spark_driver_memory_mb(spark_conf: Dict[str, str]) -> float:
     """
     Returns the Spark driver memory in MB.
     """
     # spark_conf is expected to have "spark.driver.memory" since it is a mandatory default from srv-configs.
-    driver_mem = spark_conf["spark.driver.memory"]
+    driver_mem = spark_conf['spark.driver.memory']
     try:
-        return get_spark_memory_in_unit(str(driver_mem), "m")
+        return get_spark_memory_in_unit(str(driver_mem), 'm')
     except (ValueError, IndexError):
         return SPARK_DRIVER_MEM_DEFAULT_MB

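For reference, a usage sketch of the two helpers above with hypothetical values (not part of the commit; the import assumes the same package layout the tests use):

    from service_configuration_lib import utils

    print(utils.get_spark_memory_in_unit('1.5g', 'm'))                        # 1536.0
    print(utils.get_spark_driver_memory_mb({'spark.driver.memory': '1g'}))    # 1024.0
    # Unparseable values fall back to SPARK_DRIVER_MEM_DEFAULT_MB
    # (the helper also prints a parse warning before re-raising internally).
    print(utils.get_spark_driver_memory_mb({'spark.driver.memory': 'oops'}))  # 2048
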
@@ -194,15 +194,17 @@ def get_spark_driver_memory_overhead_mb(spark_conf: Dict[str, str]) -> float:
     """
     # Use spark.driver.memoryOverhead if it is set.
     try:
-        driver_mem_overhead = spark_conf["spark.driver.memoryOverhead"]
+        driver_mem_overhead = spark_conf['spark.driver.memoryOverhead']
         try:
             # spark.driver.memoryOverhead default unit is MB
             driver_mem_overhead_mb = float(driver_mem_overhead)
         except ValueError:
-            driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), "m")
+            driver_mem_overhead_mb = get_spark_memory_in_unit(str(driver_mem_overhead), 'm')
     # Calculate spark.driver.memoryOverhead based on spark.driver.memory and spark.driver.memoryOverheadFactor.
     except Exception:
         driver_mem_mb = get_spark_driver_memory_mb(spark_conf)
-        driver_mem_overhead_factor = float(spark_conf.get("spark.driver.memoryOverheadFactor", SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT))
+        driver_mem_overhead_factor = float(
+            spark_conf.get('spark.driver.memoryOverheadFactor', SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT),
+        )
         driver_mem_overhead_mb = driver_mem_mb * driver_mem_overhead_factor
-    return driver_mem_overhead_mb
+    return round(driver_mem_overhead_mb, 5)
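
One side effect of the new round(memory_unit, 5): values that need more than five decimal places change, which is presumably why the '1024m' to 't' case in the tests below was swapped for '32768m'. A standalone check (not from the commit):

    # 1024 MB in TB is 0.0009765625; round(..., 5) gives 0.00098, not 0.001.
    print(round((1024 * 1024**2) / 1024**4, 5))   # 0.00098
    # 32768 MB in TB is exactly 0.03125 and survives rounding unchanged.
    print(round((32768 * 1024**2) / 1024**4, 5))  # 0.03125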
76 changes: 42 additions & 34 deletions tests/utils_test.py
@@ -2,19 +2,18 @@
 from socket import SO_REUSEADDR
 from socket import socket as Socket
 from socket import SOL_SOCKET
-from typing import cast
 from unittest import mock
 from unittest.mock import mock_open
 from unittest.mock import patch

 import pytest
-from typing_extensions import Literal

 from service_configuration_lib import utils
 from service_configuration_lib.utils import ephemeral_port_reserve_range
 from service_configuration_lib.utils import LOCALHOST

+from typing import cast
+from typing import Literal


 MOCK_ENV_NAME = 'mock_env_name'

@@ -80,41 +79,47 @@ def test_generate_pod_template_path(hex_value):
 @pytest.mark.parametrize(
     'mem_str,unit_str,expected_mem',
     (
-        ('13425m', "m", 13425),  # Simple case
-        ('138412032', "m", 132),  # Bytes to MB
+        ('13425m', 'm', 13425),  # Simple case
+        ('138412032', 'm', 132),  # Bytes to MB
         ('65536k', 'g', 0.0625),  # KB to GB
         ('1t', 'g', 1024),  # TB to GB
         ('1.5g', 'm', 1536),  # GB to MB with decimal
         ('2048k', 'm', 2),  # KB to MB
         ('0.5g', 'k', 524288),  # GB to KB
-        ('1024m', 't', 0.001),  # MB to TB
+        ('32768m', 't', 0.03125),  # MB to TB
         ('1.5t', 'm', 1572864),  # TB to MB with decimal
     ),
 )
 def test_get_spark_memory_in_unit(mem_str, unit_str, expected_mem):
-    assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal["k", "m", "g", "t"], unit_str))
+    assert expected_mem == utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


-def test_get_spark_memory_in_unit_exceptions():
-    with pytest.raises(ValueError):
-        utils.get_spark_memory_in_unit("1x", "k")
-    with pytest.raises(IndexError):
-        utils.get_spark_memory_in_unit("1024mb", "m")
+@pytest.mark.parametrize(
+    'mem_str,unit_str',
+    [
+        ('invalid', 'm'),
+        ('1024mb', 'g'),
+    ],
+)
+def test_get_spark_memory_in_unit_exceptions(mem_str, unit_str):
+    with pytest.raises((ValueError, IndexError)):
+        utils.get_spark_memory_in_unit(mem_str, cast(Literal['k', 'm', 'g', 't'], unit_str))


 @pytest.mark.parametrize(
     'spark_conf,expected_mem',
     [
-        ({"spark.driver.memory": "13425m"}, 13425),  # Simple case
-        ({"spark.driver.memory": "138412032"}, 132),  # Bytes to MB
-        ({"spark.driver.memory": "65536k"}, 64),  # KB to MB
-        ({"spark.driver.memory": "1g"}, 1024),  # GB to MB
-        ({"spark.driver.memory": "invalid"}, utils.SPARK_DRIVER_MEM_DEFAULT_MB),  # Invalid case
-        ({"spark.driver.memory": "1.5g"}, 1536),  # GB to MB with decimal
-        ({"spark.driver.memory": "2048k"}, 2),  # KB to MB
-        ({"spark.driver.memory": "0.5t"}, 524288),  # TB to MB
-        ({"spark.driver.memory": "1024m"}, 1024),  # MB to MB
-        ({"spark.driver.memory": "1.5t"}, 1572864),  # TB to MB with decimal
-    ]
+        ({'spark.driver.memory': '13425m'}, 13425),  # Simple case
+        ({'spark.driver.memory': '138412032'}, 132),  # Bytes to MB
+        ({'spark.driver.memory': '65536k'}, 64),  # KB to MB
+        ({'spark.driver.memory': '1g'}, 1024),  # GB to MB
+        ({'spark.driver.memory': 'invalid'}, utils.SPARK_DRIVER_MEM_DEFAULT_MB),  # Invalid case
+        ({'spark.driver.memory': '1.5g'}, 1536),  # GB to MB with decimal
+        ({'spark.driver.memory': '2048k'}, 2),  # KB to MB
+        ({'spark.driver.memory': '0.5t'}, 524288),  # TB to MB
+        ({'spark.driver.memory': '1024m'}, 1024),  # MB to MB
+        ({'spark.driver.memory': '1.5t'}, 1572864),  # TB to MB with decimal
+    ],
 )
 def test_get_spark_driver_memory_mb(spark_conf, expected_mem):
     assert expected_mem == utils.get_spark_driver_memory_mb(spark_conf)
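
The rewritten exceptions test leans on pytest.raises accepting a tuple of exception types, so one parametrized body covers both failure modes. A minimal standalone illustration (not from the commit):

    import pytest

    def test_raises_accepts_a_tuple():
        # Passes if the block raises any of the listed exception types.
        with pytest.raises((ValueError, IndexError)):
            float('not-a-number')  # raises ValueError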
@@ -123,22 +128,25 @@ def test_get_spark_driver_memory_mb(spark_conf, expected_mem):
 @pytest.mark.parametrize(
     'spark_conf,expected_mem_overhead',
     [
-        ({"spark.driver.memoryOverhead": "1024"}, 1024),  # Simple case
-        ({"spark.driver.memoryOverhead": "1g"}, 1024 * 1024),  # GB to MB
-        ({"spark.driver.memory": "10240m", "spark.driver.memoryOverheadFactor": "0.2"}, 2048),  # Custom overhead factor
-        ({"spark.driver.memory": "10240m"}, 1024),  # Using default overhead factor
-        ({"spark.driver.memory": "invalid"}, utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT),
+        ({'spark.driver.memoryOverhead': '1024'}, 1024),  # Simple case
+        ({'spark.driver.memoryOverhead': '1g'}, 1024),  # GB to MB
+        ({'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}, 2048),  # Custom OverheadFactor
+        ({'spark.driver.memory': '10240m'}, 1024),  # Using default overhead factor
+        (
+            {'spark.driver.memory': 'invalid'},
+            utils.SPARK_DRIVER_MEM_DEFAULT_MB * utils.SPARK_DRIVER_MEM_OVERHEAD_FACTOR_DEFAULT,
+        ),
         # Invalid case
-        ({"spark.driver.memoryOverhead": "1.5g"}, 1536 * 1024),  # GB to MB with decimal
-        ({"spark.driver.memory": "2048k", "spark.driver.memoryOverheadFactor": "0.05"}, 0.1),
+        ({'spark.driver.memoryOverhead': '1.5g'}, 1536),  # GB to MB with decimal
+        ({'spark.driver.memory': '2048k', 'spark.driver.memoryOverheadFactor': '0.05'}, 0.1),
         # KB to MB with custom factor
-        ({"spark.driver.memory": "0.5t", "spark.driver.memoryOverheadFactor": "0.15"}, 78643.2),
+        ({'spark.driver.memory': '0.5t', 'spark.driver.memoryOverheadFactor': '0.15'}, 78643.2),
         # TB to MB with custom factor
-        ({"spark.driver.memory": "1024m", "spark.driver.memoryOverheadFactor": "0.25"}, 256),
+        ({'spark.driver.memory': '1024m', 'spark.driver.memoryOverheadFactor': '0.25'}, 256),
         # MB to MB with custom factor
-        ({"spark.driver.memory": "1.5t", "spark.driver.memoryOverheadFactor": "0.05"}, 78643.2),
+        ({'spark.driver.memory': '1.5t', 'spark.driver.memoryOverheadFactor': '0.05'}, 78643.2),
         # TB to MB with custom factor
-    ]
+    ],
 )
 def test_get_spark_driver_memory_overhead_mb(spark_conf, expected_mem_overhead):
     assert expected_mem_overhead == utils.get_spark_driver_memory_overhead_mb(spark_conf)
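For reference, a worked check (not from the commit) of the fallback path exercised above: when spark.driver.memoryOverhead is absent, the overhead is driver memory times spark.driver.memoryOverheadFactor, defaulting to 0.1; the import assumes the same package layout the tests use:

    from service_configuration_lib import utils

    # Explicit custom factor: 10240 MB * 0.2 = 2048.0
    conf = {'spark.driver.memory': '10240m', 'spark.driver.memoryOverheadFactor': '0.2'}
    print(utils.get_spark_driver_memory_overhead_mb(conf))  # 2048.0
    # Factor omitted: the 0.1 default yields 10240 MB * 0.1 = 1024.0
    print(utils.get_spark_driver_memory_overhead_mb({'spark.driver.memory': '10240m'}))  # 1024.0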
