Add taint to user nodes #2605

Draft
wants to merge 10 commits into main
82 changes: 61 additions & 21 deletions src/_nebari/stages/infrastructure/__init__.py
@@ -43,10 +43,33 @@ class ExistingInputVars(schema.Base):
kube_context: str


class DigitalOceanNodeGroup(schema.Base):
Member Author comment: Duplicate class

class NodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
taints: Optional[List[schema.Taint]] = []

@field_validator("taints", mode="before")
def validate_taint_strings(cls, value: List[str | schema.Taint]):
TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")
parsed_taints = []
for taint in value:
if not isinstance(taint, (str, schema.Taint)):
raise ValueError(
f"Unable to parse type: {type(taint)} as taint. Must be a string or Taint object."
)

if isinstance(taint, schema.Taint):
parsed_taint = taint
elif isinstance(taint, str):
match = TAINT_STR_REGEX.match(taint)
if not match:
raise ValueError(f"Invalid taint string: {taint}")
key, value, effect = match.groups()
parsed_taint = schema.Taint(key=key, value=value, effect=effect)
parsed_taints.append(parsed_taint)

return parsed_taints
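For illustration, a minimal sketch of how this validator behaves (the usage below is hypothetical and not part of the diff; it assumes the schema import used elsewhere in this module): taints may be given either as key=value:effect strings or as schema.Taint objects, and both are normalized to Taint instances.

# Hypothetical usage sketch, not part of this PR.
ng = NodeGroup(
    instance="e2-standard-4",
    min_nodes=0,
    max_nodes=5,
    taints=[
        "dedicated=user:NoSchedule",  # string form, parsed by the regex branch above
        schema.Taint(key="gpu", value="true", effect="NoSchedule"),  # Taint objects pass through unchanged
    ],
)
# ng.taints now holds two schema.Taint objects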


class DigitalOceanInputVars(schema.Base):
@@ -55,7 +78,7 @@ class DigitalOceanInputVars(schema.Base):
region: str
tags: List[str]
kubernetes_version: str
node_groups: Dict[str, DigitalOceanNodeGroup]
node_groups: Dict[str, "DigitalOceanNodeGroup"]
kubeconfig_filename: str = get_kubeconfig_filename()


@@ -64,10 +87,26 @@ class GCPNodeGroupInputVars(schema.Base):
instance_type: str
min_size: int
max_size: int
node_taints: List[dict]
labels: Dict[str, str]
preemptible: bool
guest_accelerators: List["GCPGuestAccelerator"]

@field_validator("node_taints", mode="before")
def convert_taints(cls, value: Optional[List[schema.Taint]]):
return [
dict(
key=taint.key,
value=taint.value,
effect={
schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
}[taint.effect],
)
for taint in value
]
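In short, Kubernetes spells taint effects in CamelCase while the GKE API expects upper-snake-case constants, so this validator translates between the two when the Terraform input variables are built. A rough, hypothetical sketch of the mapping:

# Sketch only; values follow the mapping in the validator above.
taint = schema.Taint(key="dedicated", value="user", effect="NoSchedule")
# serialized for the GCP node pool as:
gke_taint = {"key": "dedicated", "value": "user", "effect": "NO_SCHEDULE"}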


class GCPPrivateClusterConfig(schema.Base):
enable_private_nodes: bool
@@ -225,16 +264,14 @@ class KeyValueDict(schema.Base):
value: str


class DigitalOceanNodeGroup(schema.Base):
class DigitalOceanNodeGroup(NodeGroup):
"""Representation of a node group with Digital Ocean

- Kubernetes limits: https://docs.digitalocean.com/products/kubernetes/details/limits/
- Available instance types: https://slugs.do-api.dev/
"""

instance: str
min_nodes: Annotated[int, Field(ge=1)] = 1
max_nodes: Annotated[int, Field(ge=1)] = 1


DEFAULT_DO_NODE_GROUPS = {
@@ -319,19 +356,26 @@ class GCPGuestAccelerator(schema.Base):
count: Annotated[int, Field(ge=1)] = 1


class GCPNodeGroup(schema.Base):
instance: str
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
class GCPNodeGroup(NodeGroup):
preemptible: bool = False
labels: Dict[str, str] = {}
guest_accelerators: List[GCPGuestAccelerator] = []


DEFAULT_GCP_NODE_GROUPS = {
"general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5),
"user": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="user", effect="NoSchedule")],
),
"worker": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=[schema.Taint(key="dedicated", value="worker", effect="NoSchedule")],
),
}


@@ -369,10 +413,8 @@ def _check_input(cls, data: Any) -> Any:
return data


class AzureNodeGroup(schema.Base):
instance: str
min_nodes: int
max_nodes: int
class AzureNodeGroup(NodeGroup):
pass


DEFAULT_AZURE_NODE_GROUPS = {
@@ -440,10 +482,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]:
return value if value is None else azure_cloud.validate_tags(value)


class AWSNodeGroup(schema.Base):
instance: str
min_nodes: int = 0
max_nodes: int
class AWSNodeGroup(NodeGroup):
gpu: bool = False
single_subnet: bool = False
permissions_boundary: Optional[str] = None
@@ -752,6 +791,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
instance_type=node_group.instance,
min_size=node_group.min_nodes,
max_size=node_group.max_nodes,
node_taints=node_group.taints,
preemptible=node_group.preemptible,
guest_accelerators=node_group.guest_accelerators,
)
Original file line number Diff line number Diff line change
@@ -93,6 +93,15 @@ resource "google_container_node_pool" "main" {

oauth_scopes = local.node_group_oauth_scopes

dynamic "taint" {
for_each = local.merged_node_groups[count.index].node_taints
content {
key = taint.value.key
value = taint.value.value
effect = taint.value.effect
}
}

metadata = {
disable-legacy-endpoints = "true"
}
@@ -108,9 +117,4 @@ resource "google_container_node_pool" "main" {
tags = var.tags
}

lifecycle {
ignore_changes = [
node_config[0].taint
]
}
}
Original file line number Diff line number Diff line change
@@ -50,20 +50,23 @@ variable "node_groups" {
min_size = 1
max_size = 1
labels = {}
node_taints = []
},
{
name = "user"
instance_type = "n1-standard-2"
min_size = 0
max_size = 2
labels = {}
node_taints = [] # TODO: Do this for other cloud providers
},
{
name = "worker"
instance_type = "n1-standard-2"
min_size = 0
max_size = 5
labels = {}
node_taints = []
}
]
}
22 changes: 22 additions & 0 deletions src/_nebari/stages/kubernetes_services/__init__.py
@@ -443,6 +443,19 @@ def handle_units(cls, value: Optional[str]) -> float:
return byte_unit_conversion(value, "GiB")


class TolerationOperatorEnum(str, enum.Enum):
Equal = "Equal"
Exists = "Exists"

@classmethod
def to_yaml(cls, representer, node):
return representer.represent_str(node.value)


class Toleration(schema.Taint):
operator: TolerationOperatorEnum = TolerationOperatorEnum.Equal
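A Toleration is simply a Taint extended with an operator (defaulting to Equal), which is what allows pod tolerations to be built directly from node-group taints. A hedged sketch of that round trip:

# Sketch only; mirrors the Toleration(**taint.model_dump()) pattern used in input_vars below.
taint = schema.Taint(key="dedicated", value="user", effect="NoSchedule")
toleration = Toleration(**taint.model_dump())  # operator falls back to TolerationOperatorEnum.Equal
# the toleration matches the taint, letting user pods schedule onto the tainted nodes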


class JupyterhubInputVars(schema.Base):
jupyterhub_theme: Dict[str, Any] = Field(alias="jupyterhub-theme")
jupyterlab_image: ImageNameTag = Field(alias="jupyterlab-image")
@@ -467,6 +480,9 @@ class JupyterhubInputVars(schema.Base):
cloud_provider: str = Field(alias="cloud-provider")
jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir")
shared_fs_type: SharedFsEnum
node_taint_tolerations: Optional[List[Toleration]] = Field(
alias="node-taint-tolerations"
)

@field_validator("jupyterhub_shared_storage", mode="before")
@classmethod
@@ -634,6 +650,12 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
jupyterlab_default_settings=self.config.jupyterlab.default_settings,
jupyterlab_gallery_settings=self.config.jupyterlab.gallery_settings,
jupyterlab_preferred_dir=self.config.jupyterlab.preferred_dir,
node_taint_tolerations=[
Toleration(**taint.model_dump())
for taint in self.config.google_cloud_platform.node_groups[
"user"
].taints
], # TODO: support other cloud providers
shared_fs_type=(
# efs is equivalent to nfs in these modules
SharedFsEnum.nfs
13 changes: 13 additions & 0 deletions src/_nebari/stages/kubernetes_services/template/dask_gateway.tf
@@ -43,6 +43,19 @@ module "dask-gateway" {

forwardauth_middleware_name = var.forwardauth_middleware_name

cluster = {
scheduler_extra_pod_config = {
tolerations = [
{
key = "dedicated"
operator = "Equal"
value = "adamworker"
effect = "NoSchedule"
}
]
}
}

depends_on = [
module.kubernetes-nfs-server,
module.rook-ceph
11 changes: 11 additions & 0 deletions src/_nebari/stages/kubernetes_services/template/jupyterhub.tf
@@ -85,6 +85,16 @@ variable "idle-culler-settings" {
type = any
}

variable "node-taint-tolerations" {
description = "Node taint toleration"
type = list(object({
key = string
operator = string
value = string
effect = string
}))
}

variable "shared_fs_type" {
type = string
description = "Use NFS or Ceph"
@@ -175,6 +185,7 @@ module "jupyterhub" {
conda-store-service-name = module.kubernetes-conda-store-server.service_name
conda-store-jhub-apps-token = module.kubernetes-conda-store-server.service-tokens.jhub-apps
jhub-apps-enabled = var.jhub-apps-enabled
node-taint-tolerations = var.node-taint-tolerations

extra-mounts = {
"/etc/dask" = {
Original file line number Diff line number Diff line change
@@ -130,23 +130,23 @@ variable "cluster" {
description = "dask gateway cluster defaults"
type = object({
# scheduler configuration
scheduler_cores = number
scheduler_cores_limit = number
scheduler_memory = string
scheduler_memory_limit = string
scheduler_extra_container_config = any
scheduler_extra_pod_config = any
scheduler_cores = optional(number, 1)
scheduler_cores_limit = optional(number, 1)
scheduler_memory = optional(string, "2 G")
scheduler_memory_limit = optional(string, "2 G")
scheduler_extra_container_config = optional(any, {})
scheduler_extra_pod_config = optional(any, {})
# worker configuration
worker_cores = number
worker_cores_limit = number
worker_memory = string
worker_memory_limit = string
worker_extra_container_config = any
worker_extra_pod_config = any
worker_cores = optional(number, 1)
worker_cores_limit = optional(number, 1)
worker_memory = optional(string, "2 G")
worker_memory_limit = optional(string, "2 G")
worker_extra_container_config = optional(any, {})
worker_extra_pod_config = optional(any, {})
# additional fields
idle_timeout = number
image_pull_policy = string
environment = map(string)
idle_timeout = optional(number, 1800) # 30 minutes
image_pull_policy = optional(string, "IfNotPresent")
environment = optional(map(string), {})
})
default = {
# scheduler configuration
Original file line number Diff line number Diff line change
@@ -241,6 +241,25 @@ def base_profile_extra_mounts():
}


def node_taint_tolerations():
tolerations = z2jh.get_config("custom.node-taint-tolerations")

if not tolerations:
return {}

return {
"tolerations": [
{
"key": taint["key"],
"operator": taint["operator"],
"value": taint["value"],
"effect": taint["effect"],
}
for taint in tolerations
]
}
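The dict returned here is merged into each user's KubeSpawner profile override (see render_profile below), so every spawned user pod carries the configured tolerations. Assuming a single dedicated=user:NoSchedule taint on the user node group, the merged override would look roughly like:

# Illustrative output only; the real values come from custom.node-taint-tolerations.
{
    "tolerations": [
        {
            "key": "dedicated",
            "operator": "Equal",
            "value": "user",
            "effect": "NoSchedule",
        }
    ]
}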


def configure_user_provisioned_repositories(username):
# Define paths and configurations
pvc_home_mount_path = f"home/{username}"
@@ -519,6 +538,7 @@ def render_profile(profile, username, groups, keycloak_profilenames):
configure_user(username, groups),
configure_user_provisioned_repositories(username),
profile_kubespawner_override,
node_taint_tolerations(),
],
{},
)
Original file line number Diff line number Diff line change
@@ -78,6 +78,7 @@ resource "helm_release" "jupyterhub" {
conda-store-jhub-apps-token = var.conda-store-jhub-apps-token
jhub-apps-enabled = var.jhub-apps-enabled
initial-repositories = var.initial-repositories
node-taint-tolerations = var.node-taint-tolerations
skel-mount = {
name = kubernetes_config_map.etc-skel.metadata.0.name
namespace = kubernetes_config_map.etc-skel.metadata.0.namespace
Original file line number Diff line number Diff line change
@@ -214,3 +214,13 @@ variable "initial-repositories" {
type = string
default = "[]"
}

variable "node-taint-tolerations" {
description = "Node taint toleration"
type = list(object({
key = string
operator = string
value = string
effect = string
}))
}