Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Enable overwrites of default environment variables #874

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions dask_kubernetes/operator/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,25 @@ def build_worker_deployment_spec(
"metadata": metadata,
"spec": spec,
}
env = [
{
"name": "DASK_WORKER_NAME",
"value": worker_name,
},
{
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
},
]
worker_env = {
"name": "DASK_WORKER_NAME",
"value": worker_name,
}
scheduler_env = {
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
}
for container in deployment_spec["spec"]["template"]["spec"]["containers"]:
if "env" in container:
container["env"].extend(env)
else:
container["env"] = env
if "env" not in container:
container["env"] = [worker_env, scheduler_env]
continue

container_env_names = [env_item["name"] for env_item in container["env"]]

if "DASK_WORKER_NAME" not in container_env_names:
container["env"].append(worker_env)
if "DASK_SCHEDULER_ADDRESS" not in container_env_names:
container["env"].append(scheduler_env)
return deployment_spec


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ metadata:
spec:
cluster: simple
worker:
replicas: 2
replicas: 1
spec:
containers:
- name: worker
Expand All @@ -23,3 +23,5 @@ spec:
env:
- name: WORKER_ENV
value: hello-world # We don't test the value, just the name
- name: DASK_WORKER_NAME
value: test-worker
89 changes: 72 additions & 17 deletions dask_kubernetes/operator/controller/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

DIR = pathlib.Path(__file__).parent.absolute()


_EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"}
_EXPECTED_LABELS = {"test-label": "label-value"}
DEFAULT_CLUSTER_NAME = "simple"
Expand All @@ -47,7 +46,6 @@ def gen_cluster(k8s_cluster, ns, gen_cluster_manifest):

@asynccontextmanager
async def cm(cluster_name=DEFAULT_CLUSTER_NAME):

cluster_path = gen_cluster_manifest(cluster_name)
# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path)
Expand Down Expand Up @@ -95,6 +93,36 @@ async def cm(job_file):
yield cm


@pytest.fixture()
def gen_worker_group(k8s_cluster, ns):
    """Yields an instantiated context manager for creating/deleting a worker group."""

    @asynccontextmanager
    async def cm(worker_group_file):
        """Apply the DaskWorkerGroup manifest named *worker_group_file* (relative
        to the test ``resources`` directory), wait until the resource is listed by
        the API server, and yield ``(worker_group_name, namespace)``. On exit,
        delete the resource and wait until it is gone.
        """
        worker_group_path = os.path.join(DIR, "resources", worker_group_file)
        with open(worker_group_path) as f:
            # safe_load suffices for plain Kubernetes manifests and avoids
            # constructing arbitrary Python objects from the YAML document.
            worker_group_name = yaml.safe_load(f)["metadata"]["name"]

        # Create the worker group resource, then poll until it shows up.
        k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path)
        while worker_group_name not in k8s_cluster.kubectl(
            "get", "daskworkergroups.kubernetes.dask.org", "-n", ns
        ):
            await asyncio.sleep(0.1)

        try:
            yield worker_group_name, ns
        finally:
            # NOTE(review): delete is issued without a blocking wait; the polling
            # loop below keeps the event loop free so the operator can process
            # the deletion — presumably a synchronous wait stalled it. Confirm.
            k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path)
            while worker_group_name in k8s_cluster.kubectl(
                "get", "daskworkergroups.kubernetes.dask.org", "-n", ns
            ):
                await asyncio.sleep(0.1)

    yield cm


def test_customresources(k8s_cluster):
assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
Expand Down Expand Up @@ -671,32 +699,59 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster):


@pytest.mark.anyio
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster):
async def test_object_dask_worker_group(
k8s_cluster, kopf_runner, gen_cluster, gen_worker_group
):
with kopf_runner:
async with gen_cluster() as (cluster_name, ns):
async with (
gen_cluster() as (cluster_name, ns),
gen_worker_group("simpleworkergroup.yaml") as (
additional_workergroup_name,
_,
),
):
cluster = await DaskCluster.get(cluster_name, namespace=ns)
additional_workergroup = await DaskWorkerGroup.get(
additional_workergroup_name, namespace=ns
)

worker_groups = []
while not worker_groups:
worker_groups = await cluster.worker_groups()
await asyncio.sleep(0.1)
assert len(worker_groups) == 1 # Just the default worker group
wg = worker_groups[0]
assert isinstance(wg, DaskWorkerGroup)
worker_groups = worker_groups + [additional_workergroup]

pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])
for wg in worker_groups:
assert isinstance(wg, DaskWorkerGroup)

deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])
deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])

assert (await wg.cluster()).name == cluster.name
pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])

assert (await wg.cluster()).name == cluster.name

for deployment in deployments:
assert deployment.labels["dask.org/cluster-name"] == cluster.name
for env in deployment.spec["template"]["spec"]["containers"][0][
"env"
]:
if env["name"] == "DASK_WORKER_NAME":
if wg.name == additional_workergroup_name:
assert env["value"] == "test-worker"
else:
assert env["value"] == deployment.name
if env["name"] == "DASK_SCHEDULER_ADDRESS":
scheduler_service = await cluster.scheduler_service()
assert f"{scheduler_service.name}.{ns}" in env["value"]


@pytest.mark.anyio
Expand Down
Loading