Dask executor

Documentation generated from the Dask executor docstrings.

Module executors.dask_executor


Overview

Handles execution of the surrogate workflow on Dask. Supports both SLURMCluster and LocalCluster for distributed task execution. See the SLURMCluster documentation: https://jobqueue.dask.org/en/latest/index.html


Clusters

Local cluster

Can be used for running on a local machine with multiple cores. Useful for testing or small-scale runs.

Arguments:

n_workers: 2
threads_per_worker: 1
memory_limit: '12GB'
processes: 1

Example configuration: /configs/example_dask_local.yaml
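The local-cluster arguments above can be sketched in Python as a plain config dictionary. This is a minimal sketch under the assumption that the YAML block is loaded into a dict and unpacked into the cluster constructor; the dask call itself is shown commented out so the sketch has no hard dependency on dask:

```python
# Sketch (assumption): a LocalCluster_config as a Python dict, mirroring the
# arguments listed above.
local_cluster_config = {
    "n_workers": 2,
    "threads_per_worker": 1,
    "memory_limit": "12GB",
    "processes": 1,
}

def expected_workers(config):
    # For a local cluster the executor's expected worker count is simply
    # the configured n_workers.
    return config["n_workers"]

print(expected_workers(local_cluster_config))  # 2

# With dask installed, the config would be unpacked into the constructor:
# from dask.distributed import LocalCluster, Client
# cluster = LocalCluster(**local_cluster_config)
# client = Client(cluster)
```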

SLURM cluster

Arguments for the SLURM workers.

account: 'project_xxx'
queue: 'medium'
cores: 1
memory: '12GB'
processes: 1
walltime: '00:20:00'
config_name: 'slurm'
interface: 'ib0'

Example configuration: /configs/example_dask_slurm.yaml
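As a hedged sketch of how the SLURM worker arguments above might be used: the executor scales the cluster to `scale_n_jobs` SLURM jobs, and each job hosts `processes` Dask workers, so the expected worker count is their product. The dict below mirrors the listed arguments; the dask-jobqueue call is shown commented out so the sketch stays dependency-free:

```python
# Sketch (assumption): a SLURMcluster_config dict mirroring the arguments
# above, plus how the executor derives its expected worker count.
slurm_cluster_config = {
    "account": "project_xxx",
    "queue": "medium",
    "cores": 1,
    "memory": "12GB",
    "processes": 1,
    "walltime": "00:20:00",
    "config_name": "slurm",
    "interface": "ib0",
}

def total_workers(config, scale_n_jobs):
    # Each SLURM job hosts `processes` workers; scaling to N jobs
    # therefore yields N * processes workers.
    return scale_n_jobs * int(config.get("processes", 1))

print(total_workers(slurm_cluster_config, 4))  # 4

# With dask-jobqueue installed:
# from dask_jobqueue import SLURMCluster
# cluster = SLURMCluster(silence_logs=False, **slurm_cluster_config)
# cluster.scale(4)
```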

Notes

Other arguments:

job_script_prologue: ['module load your-modules-here']
job_extra_directives: [
    '-o tmp_path_hm/worker_out_MishkaRunner_1/%x.%j.out',
    '-e tmp_path_hm/worker_out_MishkaRunner_1/%x.%j.err']

DaskExecutor

DaskExecutor(*args, **kwargs)

Bases: Executor

Initializes the DaskExecutor.

Parameters:

Name Type Description Default
base_run_dir str

Base directory for storing run outputs.

required
sampler_config dict

Arguments for the sampler, including its type.

required
runner_config dict

Arguments for the runner.

required
*args

Additional positional arguments.

()
**kwargs

Additional keyword arguments, including:

- type (str): Type of executor.
- scale_n_jobs (int): Number of jobs to scale the cluster to.
- SLURMcluster_config (dict): Arguments for SLURMCluster.
- LocalCluster_config (dict): Arguments for LocalCluster.
- block_until_cluster_started (bool): Whether to block until the cluster is fully started.

{}
Source code in src/enchanted_surrogates/executors/dask_executor.py
def __init__(self, *args, **kwargs):
    """
    Initializes the DaskExecutor.

    Args:
        base_run_dir (str): Base directory for storing run outputs.
        sampler_config (dict): Arguments for the sampler, including its type.
        runner_config (dict): Arguments for the runner.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments, including:
            - type (str): Type of executor.
            - scale_n_jobs (int): Number of jobs to scale the cluster to.
            - SLURMcluster_config (dict): Arguments for SLURMCluster.
            - LocalCluster_config (dict): Arguments for LocalCluster.
            - block_until_cluster_started (bool): Whether to block until the cluster is fully started.
    """
    super().__init__(*args, **kwargs)
    log.info("INITIALISING DASK EXECUTOR")
    self.scale_n_jobs = kwargs.get("scale_n_jobs", 1)
    self.timeout = kwargs.get("timeout", 1e10)
    self.SLURMcluster_config = kwargs.get("SLURMcluster_config")
    self.LocalCluster_config = kwargs.get("LocalCluster_config")
    self.block_until_cluster_started = kwargs.get(
        "block_until_cluster_started", False
    )  # for debugging purposes only
    self.cluster = None
    self.client = None
    self.expected_number_of_workers = None
    self.slurm_job_ids = set()
    self.is_closed = False

clean

clean()

Cleans up resources by shutting down the Dask cluster.

This method is intended to be called when the executor is no longer needed.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def clean(self):
    """
    Cleans up resources by shutting down the Dask cluster.

    This method is intended to be called when the executor is no longer needed.
    """
    if self.is_closed:
        log.debug("Trying to close cluster that has already been closed!")
        return

    self.shutdown_cluster()
    self.is_closed = True

execute

execute(input, sampler)

Starts the execution of simulation tasks using the configured Dask cluster.

This method initializes the base run directory, checks for existing data to avoid overwrites, and submits tasks to the Dask cluster in batches. It collects results from completed tasks and writes them to a CSV file.

Raises:

Type Description
FileExistsError

If the base run directory contains a file indicating a completed run.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def execute(self, input: list[(str, dict)], sampler):
    """
    Starts the execution of simulation tasks using the configured Dask cluster.

    This method initializes the base run directory, checks for existing data to avoid overwrites,
    and submits tasks to the Dask cluster in batches. It collects results from completed tasks
    and writes them to a CSV file.

    Raises:
        FileExistsError: If the base run directory contains a file indicating a completed run.
    """
    assert sampler

    inputlist = list(input)
    self.base_run_dir = os.path.dirname(inputlist[0][0])

    if not self.client:
        self.start_cluster()
    log.info("CLUSTER STARTED")

    # keep futures for BayesianOptimizationSampler
    futures = self.submit_batch(inputlist)

    sampler_type = (
        getattr(sampler, "type", None)
        or getattr(sampler, "__class__", None).__name__
    )

    if sampler_type in {"BayesianOptimizationSampler"}:
        try:
            wait(futures, timeout=self.timeout)
        except Exception:
            print("Timeout or error while waiting for BO batch; continuing.")

        if getattr(sampler, "plot_GPR_flag", False):
            try:
                sampler.build_result_dictionary(self.base_run_dir)
                sampler.plot_frequency = 1
                sampler.train_surrogate()
            except Exception as e:
                log.error(f"Error during sampler postprocessing: {e}")
    else:
        dfs = []
        num_success = 0
        total = len(futures)

        log.info(f"Collecting results from {total} futures...")

        for i, future in enumerate(as_completed(futures), start=1):
            try:
                result = future.result()
            except Exception as e:
                print(f"[{i}/{total}] Future failed with exception:", e)
                continue

            if isinstance(result, dict) and result.get("success") is True:
                num_success += 1

            try:
                df = pd.DataFrame({k: [v] for k, v in result.items()})
                dfs.append(df)
            except Exception as e:
                log.error(f"Failed to convert result to DataFrame: {e}")
                continue

            log.info(
                f"[{i}/{total}] Futures Completed ({(i / total) * 100:.1f}%) | "
                f"[{num_success}/{i}] Futures Succeeded"
            )
            log.info("_" * 100)
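The submit/collect pattern in `execute` can be illustrated with the standard library's `concurrent.futures` standing in for a Dask client. This is a hedged stand-in, not the real implementation: `run_simulation_task` below is a dummy; the real one lives in the enchanted_surrogates codebase and takes the runner config, run directory, and sample parameters:

```python
# Stand-in for the submit/collect loop above, using concurrent.futures
# instead of a Dask client.
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_simulation_task(runner_config, run_dir, sample_params):
    # Dummy task: the real task runs a simulation and returns a result dict.
    return {"success": True, "run_dir": run_dir, **sample_params}

pairs = [(f"run_{i}", {"x": i}) for i in range(3)]
results, num_success = [], 0
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(run_simulation_task, {}, d, p) for d, p in pairs]
    for future in as_completed(futures):
        try:
            result = future.result()
        except Exception:
            continue  # a failed future is skipped, as in the executor
        if result.get("success") is True:
            num_success += 1
        results.append(result)

print(num_success)  # 3
```

The design choice mirrored here is that failed futures are logged and skipped rather than aborting the batch, so one crashed simulation does not lose the rest of the results.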

find_line_in_seff_output

find_line_in_seff_output(lines, entry)

Helper function to quickly find the required line in the seff output

Parameters:

Name Type Description Default
lines list

list of lines

required
entry str

the entry that is being looked for

required

Returns: str: The time or percentage from the matching line; defaults to "".

Source code in src/enchanted_surrogates/executors/dask_executor.py
def find_line_in_seff_output(self, lines, entry):
    """
    Helper function to quickly find the required line in the seff output

    Params:
        lines (list): list of lines
        entry (str): the entry that is being looked for
    Returns:
        str: time or percentage from the corresponding line, defaults to ""
    """
    return next(
        (
            line.replace(entry, "").strip()
            for line in lines
            if line.startswith(entry)
        ),
        "",
    )
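A quick usage sketch of the same lookup logic in isolation (the `seff_lines` sample below is illustrative, not real `seff` output):

```python
# find_line_in_seff_output in miniature: scan seff's stdout lines for the
# first one starting with `entry` and return the remainder, stripped.
def find_line(lines, entry):
    return next(
        (line.replace(entry, "").strip() for line in lines if line.startswith(entry)),
        "",  # default when no line matches
    )

seff_lines = [
    "Job ID: 123456",
    "CPU Utilized: 00:01:30",
    "CPU Efficiency: 95.00% of 00:01:35 core-walltime",
]
print(find_line(seff_lines, "CPU Utilized:"))    # 00:01:30
print(find_line(seff_lines, "Memory Utilized:")) # (empty string)
```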

get_all_dask_job_ids

get_all_dask_job_ids()

Runs squeue to find all jobs belonging to the cluster.

Returns: list: A list of Dask job IDs.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def get_all_dask_job_ids(self):
    """
    Runs squeue to figure out all jobs from the cluster
    Returns:
        list: A list of Dask job IDs.
    """
    try:
        jobs = []
        result = subprocess.run(
            ["squeue", "--me"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        dask_lines = [
            line
            for line in result.stdout.splitlines()
            if "sys/dash" not in line and "enc_dask_worker" in line
        ]

        if not dask_lines:
            log.debug("No Dask jobs found in queue.")
            return []

        for line in dask_lines:
            fields = line.split()
            jobs.append(fields[0])

        return jobs

    except Exception as e:
        if self.is_running_on_slurm():
            log.warning(f"Error while checking squeue: {e}")
        return []
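The filtering above can be exercised against a captured `squeue --me` dump. The sample output below is illustrative only (real `squeue` output has more columns and site-specific job names):

```python
# Sketch of the squeue filtering used above: keep lines that mention the
# worker job name and drop the dashboard job, then take the first field
# (the job ID) from each remaining line.
squeue_stdout = """\
 JOBID PARTITION     NAME USER ST   TIME  NODES
123456    medium enc_dask_worker me  R   0:42      1
123457    medium sys/dash        me  R   1:02      1
123458    medium enc_dask_worker me PD   0:00      1
"""

def dask_job_ids(stdout):
    jobs = []
    for line in stdout.splitlines():
        if "sys/dash" in line or "enc_dask_worker" not in line:
            continue
        jobs.append(line.split()[0])
    return jobs

print(dask_job_ids(squeue_stdout))  # ['123456', '123458']
```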

get_slurm_usage_info

get_slurm_usage_info(job_id=None)

Parameters:

Name Type Description Default
job_id list[int]

If you only want usage info for specific jobs, pass their IDs here.

None

Returns: list: A list of dictionaries containing the parsed output of seff.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def get_slurm_usage_info(self, job_id=None):
    """
    Params:
        job_id (list[int], optional): If given, only fetch usage info for these jobs.
    Returns:
        list: A list of dictionaries containing the parsed output of seff.
    """
    job_ids = job_id if job_id else self.get_all_dask_job_ids()
    job_info = []

    for job_id in job_ids:
        try:
            result = subprocess.run(
                ["seff", str(job_id)],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
            output = result.stdout
            seff_lines = output.splitlines()

            if len(output.strip()) == 0:
                continue

            cpu_time = self.find_line_in_seff_output(seff_lines, "CPU Utilized:")
            cpu_efficiency = self.find_line_in_seff_output(
                seff_lines, "CPU Efficiency:"
            )
            memory_used = self.find_line_in_seff_output(
                seff_lines, "Memory Utilized:"
            )
            memory_efficiency = self.find_line_in_seff_output(
                seff_lines, "Memory Efficiency:"
            )

            hours, minutes, seconds = map(int, cpu_time.split(":"))
            cpu_secs = hours * 3600 + minutes * 60 + seconds

            job_info.append(
                {
                    "cpu_time": cpu_time,
                    "cpu_time_seconds": cpu_secs,
                    "cpu_efficiency": cpu_efficiency,
                    "memory_efficiency": memory_efficiency,
                    "memory_used": memory_used,
                    "job_id": job_id,
                }
            )
        except Exception as e:
            if self.is_running_on_slurm():
                log.error(
                    f"Error fetching SLURM resource usage for job {job_id}: {e}"
                )
            else:
                log.error("Not running on SLURM. skipping resources")
                return [
                    {
                        "cpu_time": "00:00:00",
                        "cpu_time_seconds": 0,
                        "cpu_efficiency": "100%",
                        "memory_efficiency": "100%",
                        "memory_used": "0",
                        "job_id": 0,
                    }
                ]

    return job_info
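The conversion of "CPU Utilized" into seconds, done inline above, is worth seeing on its own. One caveat (an observation, not something the source states): `seff` can report times longer than a day as `D-HH:MM:SS`, which a plain three-way split would not handle, so this minimal version assumes `HH:MM:SS`:

```python
# Minimal standalone version of the HH:MM:SS -> seconds conversion above.
# Assumes the seff time never includes a days component.
def cpu_time_to_seconds(cpu_time):
    hours, minutes, seconds = map(int, cpu_time.split(":"))
    return hours * 3600 + minutes * 60 + seconds

print(cpu_time_to_seconds("00:20:00"))  # 1200
print(cpu_time_to_seconds("01:02:03"))  # 3723
```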

is_running_on_slurm

is_running_on_slurm()

Checks whether the code is running on SLURM or locally, by checking if seff exists.

Returns: bool: True if on SLURM, False otherwise.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def is_running_on_slurm(self):
    """
    Checks whether the code is running on SLURM or locally, by checking if seff exists.
    Returns:
        bool: True if on SLURM, False otherwise.
    """
    try:
        proc = subprocess.run(
            ["which", "seff"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=1,
        )
    except Exception:
        return False

    return proc.returncode == 0
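An equivalent check (assumption: the intent is only "is `seff` on the PATH") can be written without spawning a subprocess, using the standard library's `shutil.which`:

```python
import shutil

# shutil.which returns the resolved path of an executable on PATH,
# or None if it is not found -- no subprocess needed.
def is_running_on_slurm():
    return shutil.which("seff") is not None

print(is_running_on_slurm())  # environment-dependent
```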

shutdown_cluster

shutdown_cluster()

Shuts down the Dask cluster, including the scheduler and workers.

Note

This will also shut down the scheduler, which may not be desired if the scheduler is controlling other clusters. To only shut down the workers, use a different method.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def shutdown_cluster(self):
    """
    Shuts down the Dask cluster, including the scheduler and workers.

    Note:
        This will also shut down the scheduler, which may not be desired if the scheduler
        is controlling other clusters. To only shut down the workers, use a different method.
    """
    self.client.shutdown()

start_cluster

start_cluster()

Starts a Dask cluster using either SLURMCluster or LocalCluster.

If SLURMCluster is used, it sets up SLURM-specific configurations, including output directories for worker logs. If LocalCluster is used, it initializes a local Dask cluster.


Raises:

Type Description
ValueError

If no workers are successfully started.

Warning

If fewer workers than expected are started.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def start_cluster(self):
    """
    Starts a Dask cluster using either SLURMCluster or LocalCluster.

    If SLURMCluster is used, it sets up SLURM-specific configurations, including output directories
    for worker logs. If LocalCluster is used, it initializes a local Dask cluster.

    Raises:
        ValueError: If no workers are successfully started.
        Warning: If fewer workers than expected are started.
    """
    log.info("Creating a cluster...")
    slurm_out_dir = LoggerConfig().log_dir

    if self.SLURMcluster_config:
        self.expected_number_of_workers = self.scale_n_jobs * int(
            self.SLURMcluster_config.get("processes", 1)
        )

        log.info(f"Output of SLURM workers saved in: {slurm_out_dir}")
        self.cluster = SLURMCluster(silence_logs=False, **self.SLURMcluster_config)
        self.cluster.scale(self.scale_n_jobs)
        log.debug(f"The job script for a worker is:\n{self.cluster.job_script()}")

        self.client = Client(self.cluster, timeout=180)

        # Register the log plugin
        plugin = SLURMLogPlugin(LoggerConfig())
        self.client.register_plugin(plugin, name="LogPlugin")

        log.info(f"SCHEDULER ADDRESS: {self.cluster.scheduler_address}")
        log.info(f"DASHBOARD LINK: {self.client.dashboard_link}")

        if self.block_until_cluster_started:
            log.info("WAIT UNTIL ALL dask-wor JOBS ARE RUNNING")
            self.wait_for_all_dask_jobs_running()

    elif self.LocalCluster_config:
        self.expected_number_of_workers = self.LocalCluster_config["n_workers"]
        self.cluster = LocalCluster(silence_logs=False, **self.LocalCluster_config)
        self.client = Client(self.cluster)

        # Register the log plugin
        plugin = DaskLocalLogPlugin(LoggerConfig())
        self.client.register_plugin(plugin, name="LogPlugin")

    if self.block_until_cluster_started:
        log.info(
            f"Waiting for {self.expected_number_of_workers} workers to start..."
        )
        for i in range(1, self.expected_number_of_workers + 2):
            if i == self.expected_number_of_workers + 1:
                timeout_ = 3
                try:
                    self.client.wait_for_workers(i, timeout=timeout_)
                    log.warning(
                        f"MORE WORKERS WERE STARTED THAN THE EXPECTED {self.expected_number_of_workers}"
                    )
                except TimeoutError:
                    log.error(
                        f"IN {timeout_} SEC NO UNEXPECTED WORKERS WERE STARTED.\n"
                    )
            else:
                self.client.wait_for_workers(
                    i, timeout=self.expected_number_of_workers + 120
                )
                log.info(
                    f"Connected to {i} workers out of expected {self.expected_number_of_workers}.\n"
                )

        workers = self.client.scheduler_info()["workers"]
        log.info("SOME WORKER INFORMATION:")
        for addr, info in workers.items():
            log.info(f"Worker {addr}:")
            log.info(f"  CPUs: {info['nthreads']}")
            log.info(f"  Memory: {info['memory_limit'] / 1e9:.2f} GB")
            log.info(f"  Resources: {info.get('resources', {})}\n")

    # Only for SLURM cluster
    if hasattr(self.cluster, "workers"):
        try:
            self.slurm_job_ids.update(self.cluster.workers.keys())
        except Exception:
            pass

submit_batch

submit_batch(run_dir_sample_pairs, base_run_dir=None, client=None, include_fut_to_rundir=False)

Submits a batch of simulation tasks to the Dask cluster.

Each task is submitted with its own unique run directory. The tasks are executed asynchronously, and their futures are returned for tracking.

Parameters:

Name Type Description Default
run_dir_sample_pairs list

List of (run_dir, sample_params) pairs for the simulation tasks.

required

Returns:

Name Type Description
list

List of futures representing the submitted tasks.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def submit_batch(
    self,
    run_dir_sample_pairs,
    base_run_dir=None,
    client=None,
    include_fut_to_rundir=False,
):
    """
    Submits a batch of simulation tasks to the Dask cluster.

    Each task is submitted with its own unique run directory. The tasks are executed
    asynchronously, and their futures are returned for tracking.

    Args:
        run_dir_sample_pairs (list): List of (run_dir, sample_params) pairs for the simulation tasks.

    Returns:
        list: List of futures representing the submitted tasks.
    """
    if not client:
        client = self.client
    assert client is not None

    futures = []
    fut_to_rundir = {}
    for run_dir, sample_params in run_dir_sample_pairs:
        new_future = client.submit(
            run_simulation_task, self.runner_config, run_dir, sample_params
        )
        futures.append(new_future)
        fut_to_rundir[new_future.key] = run_dir

    log.info(
        f"{len(futures)} DASK FUTURES SUBMITTED for runner {self.runner_config['type']}"
    )
    if include_fut_to_rundir:
        return futures, fut_to_rundir
    return futures
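The future-to-rundir bookkeeping above can be sketched with `concurrent.futures` in place of the Dask client. One caveat: stdlib futures have no `.key` attribute like Dask futures do, so the future object itself is used as the dictionary key here:

```python
# Stand-in for submit_batch's fut_to_rundir bookkeeping.
from concurrent.futures import ThreadPoolExecutor

pairs = [("run_a", {"x": 1}), ("run_b", {"x": 2})]
fut_to_rundir = {}
with ThreadPoolExecutor() as pool:
    futures = [pool.submit(lambda p: p, params) for _, params in pairs]
    # Map each future back to the run directory it was submitted for.
    for fut, (run_dir, _) in zip(futures, pairs):
        fut_to_rundir[fut] = run_dir

print(sorted(fut_to_rundir.values()))  # ['run_a', 'run_b']
```

This mapping is what lets a caller recover which run directory a completed (or failed) future belongs to.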

wait_for_all_dask_jobs_running

wait_for_all_dask_jobs_running(poll_interval=1)

Waits for all Dask jobs submitted to SLURM to reach the RUNNING state.

This method repeatedly checks the SLURM job queue for jobs with the prefix 'dask-wor'. If any job is not in the RUNNING state, it waits and retries until all jobs are running.

Parameters:

Name Type Description Default
poll_interval int

Time interval (in seconds) between checks. Defaults to 1.

1

Raises:

Type Description
Exception

If an error occurs while checking the SLURM queue.

Source code in src/enchanted_surrogates/executors/dask_executor.py
def wait_for_all_dask_jobs_running(self, poll_interval=1):
    """
    Waits for all Dask jobs submitted to SLURM to reach the RUNNING state.

    This method repeatedly checks the SLURM job queue for jobs with the prefix 'dask-wor'.
    If any job is not in the RUNNING state, it waits and retries until all jobs are running.

    Args:
        poll_interval (int, optional): Time interval (in seconds) between checks. Defaults to 1.

    Raises:
        Exception: If an error occurs while checking the SLURM queue.
    """
    log.info("Waiting for all Dask jobs to enter RUNNING state...")

    while True:
        try:
            # Run squeue --me and capture output
            result = subprocess.run(
                ["squeue", "--me"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
            output = result.stdout

            # Filter lines containing 'dask-wor'
            dask_lines = [
                line for line in output.splitlines() if "dask-wor" in line
            ]

            if not dask_lines:
                log.info("No Dask jobs found in queue.")
                time.sleep(poll_interval)
                continue

            # Check job states
            all_running = True
            for line in dask_lines:
                fields = line.split()
                job_id = fields[0]
                job_state = fields[4]  # Typically the 5th column is state

                if job_state == "PD":
                    log.info("=" * 100)
                    log.info("\n".join(dask_lines))
                    log.info(f"Job {job_id} is in state {job_state} — waiting...")
                    log.info("=" * 100)
                    all_running = False
                    break

            if all_running:
                log.info("All Dask jobs are RUNNING.")
                return

            time.sleep(poll_interval)

        except Exception as e:
            log.error(f"Error while checking squeue: {e}")
            time.sleep(poll_interval)
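The pending-state check inside the polling loop can be isolated as a pure function over squeue lines, which makes the "5th column is the state" assumption explicit (the sample lines below are illustrative, not real scheduler output):

```python
# The PD-state check from the loop above, isolated: given squeue output
# lines, report whether every 'dask-wor' job has left the pending state.
def all_dask_jobs_running(squeue_lines):
    dask_lines = [line for line in squeue_lines if "dask-wor" in line]
    if not dask_lines:
        return False  # nothing in the queue yet
    # Field 4 is typically the ST (state) column; "PD" means pending.
    return all(line.split()[4] != "PD" for line in dask_lines)

lines = [
    "123456 medium dask-wor me R 0:42 1",
    "123457 medium dask-wor me PD 0:00 1",
]
print(all_dask_jobs_running(lines))      # False
print(all_dask_jobs_running(lines[:1]))  # True
```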