Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(clp-package): Let clp-config objects check against unexpected fields. #676

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions components/clp-py-utils/clp_py_utils/clp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Literal, Optional, Tuple, Union

from dotenv import dotenv_values
from pydantic import BaseModel, PrivateAttr, validator
from pydantic import BaseModel, Extra, PrivateAttr, validator
from strenum import KebabCaseStrEnum, LowercaseStrEnum

from .clp_logging import get_valid_logging_level, is_valid_logging_level
Expand Down Expand Up @@ -57,7 +57,12 @@ class StorageType(LowercaseStrEnum):
VALID_STORAGE_ENGINES = [storage_engine.value for storage_engine in StorageEngine]


class Package(BaseModel):
class BaseModelForbidExtra(BaseModel):
class Config:
extra = Extra.forbid


class Package(BaseModelForbidExtra):
storage_engine: str = "clp"

@validator("storage_engine")
Expand All @@ -70,7 +75,7 @@ def validate_storage_engine(cls, field):
return field


class Database(BaseModel):
class Database(BaseModelForbidExtra):
type: str = "mariadb"
host: str = "localhost"
port: int = 3306
Expand Down Expand Up @@ -174,7 +179,7 @@ def _validate_port(cls, field):
)


class CompressionScheduler(BaseModel):
class CompressionScheduler(BaseModelForbidExtra):
jobs_poll_delay: float = 0.1 # seconds
logging_level: str = "INFO"

Expand All @@ -184,7 +189,7 @@ def validate_logging_level(cls, field):
return field


class QueryScheduler(BaseModel):
class QueryScheduler(BaseModelForbidExtra):
host = "localhost"
port = 7000
jobs_poll_delay: float = 0.1 # seconds
Expand All @@ -209,7 +214,7 @@ def validate_port(cls, field):
return field


class CompressionWorker(BaseModel):
class CompressionWorker(BaseModelForbidExtra):
logging_level: str = "INFO"

@validator("logging_level")
Expand All @@ -218,7 +223,7 @@ def validate_logging_level(cls, field):
return field


class QueryWorker(BaseModel):
class QueryWorker(BaseModelForbidExtra):
logging_level: str = "INFO"

@validator("logging_level")
Expand All @@ -227,7 +232,7 @@ def validate_logging_level(cls, field):
return field


class Redis(BaseModel):
class Redis(BaseModelForbidExtra):
host: str = "localhost"
port: int = 6379
query_backend_database: int = 0
Expand All @@ -242,7 +247,7 @@ def validate_host(cls, field):
return field


class Reducer(BaseModel):
class Reducer(BaseModelForbidExtra):
host: str = "localhost"
base_port: int = 14009
logging_level: str = "INFO"
Expand Down Expand Up @@ -272,7 +277,7 @@ def validate_upsert_interval(cls, field):
return field


class ResultsCache(BaseModel):
class ResultsCache(BaseModelForbidExtra):
host: str = "localhost"
port: int = 27017
db_name: str = "clp-query-results"
Expand Down Expand Up @@ -302,15 +307,15 @@ def get_uri(self):
return f"mongodb://{self.host}:{self.port}/{self.db_name}"


class Queue(BaseModel):
class Queue(BaseModelForbidExtra):
host: str = "localhost"
port: int = 5672

username: Optional[str]
password: Optional[str]


class S3Credentials(BaseModel):
class S3Credentials(BaseModelForbidExtra):
access_key_id: str
secret_access_key: str

Expand All @@ -327,7 +332,7 @@ def validate_secret_access_key(cls, field):
return field


class S3Config(BaseModel):
class S3Config(BaseModelForbidExtra):
region_code: str
bucket: str
key_prefix: str
Expand Down Expand Up @@ -360,7 +365,7 @@ def get_credentials(self) -> Tuple[Optional[str], Optional[str]]:
return self.credentials.access_key_id, self.credentials.secret_access_key


class FsStorage(BaseModel):
class FsStorage(BaseModelForbidExtra):
type: Literal[StorageType.FS.value] = StorageType.FS.value
directory: pathlib.Path

Expand All @@ -379,7 +384,7 @@ def dump_to_primitive_dict(self):
return d


class S3Storage(BaseModel):
class S3Storage(BaseModelForbidExtra):
type: Literal[StorageType.S3.value] = StorageType.S3.value
staging_directory: pathlib.Path
s3_config: S3Config
Expand Down Expand Up @@ -437,7 +442,7 @@ def _set_directory_for_storage_config(
raise NotImplementedError(f"storage.type {storage_type} is not supported")


class ArchiveOutput(BaseModel):
class ArchiveOutput(BaseModelForbidExtra):
storage: Union[ArchiveFsStorage, ArchiveS3Storage] = ArchiveFsStorage()
target_archive_size: int = 256 * 1024 * 1024 # 256 MB
target_dictionaries_size: int = 32 * 1024 * 1024 # 32 MB
Expand Down Expand Up @@ -480,7 +485,7 @@ def dump_to_primitive_dict(self):
return d


class StreamOutput(BaseModel):
class StreamOutput(BaseModelForbidExtra):
storage: Union[StreamFsStorage, StreamS3Storage] = StreamFsStorage()
target_uncompressed_size: int = 128 * 1024 * 1024

Expand All @@ -502,7 +507,7 @@ def dump_to_primitive_dict(self):
return d


class WebUi(BaseModel):
class WebUi(BaseModelForbidExtra):
host: str = "localhost"
port: int = 4000
logging_level: str = "INFO"
Expand All @@ -523,7 +528,7 @@ def validate_logging_level(cls, field):
return field


class LogViewerWebUi(BaseModel):
class LogViewerWebUi(BaseModelForbidExtra):
host: str = "localhost"
port: int = 3000

Expand All @@ -538,7 +543,7 @@ def validate_port(cls, field):
return field


class CLPConfig(BaseModel):
class CLPConfig(BaseModelForbidExtra):
execution_container: Optional[str] = None

input_logs_directory: pathlib.Path = pathlib.Path("/")
Expand Down Expand Up @@ -678,7 +683,7 @@ def dump_to_primitive_dict(self):
return d


class WorkerConfig(BaseModel):
class WorkerConfig(BaseModelForbidExtra):
package: Package = Package()
archive_output: ArchiveOutput = ArchiveOutput()
data_directory: pathlib.Path = CLPConfig().data_directory
Expand Down
Loading