
Supervisor

Supervisor is where the main loop of Enchanted Surrogates runs. The Supervisor orchestrates the samplers, executors, and runners. See the chart below for the overall structure of the code.

Workflow chart

Configuration of supervisor

Supervisor requires base_run_dir to be defined in the configuration file, for example:

supervisor:
    base_run_dir: "path/to/folder"
    run_order:
    -   executor: ...
        sampler: ...
        runner: ...

Optional attributes

The enchanted_dataset summary file, which combines all run results, can also be written as Parquet instead of CSV. CSV is the default and requires no configuration.

supervisor:
    summary_datatype: "parquet" # csv by default

The HDF5 storage file is skipped only when the storage type is set to "None"; it is created in every other case.

storage:
    type: "hdf5" # or "None"

Unnecessary files can be deleted from base_run_dir so that only the wanted files are kept. By default, everything is saved. The custom option saves only the listed files, while none saves no extra files.

Note: enchanted_dataset.csv and runs.h5 are always kept, regardless of this setting.

supervisor:
  save_files: "all" # or "custom" or "none"
  # if using custom, only described files are saved
  save_files_list:
    - file.txt
    - file2.txt 
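
A hypothetical configuration combining the options above might look as follows. The path and file names are placeholders, and the run_order entries are elided just as in the examples above:

```yaml
supervisor:
    base_run_dir: "path/to/folder"
    summary_datatype: "parquet" # csv by default
    save_files: "custom"
    save_files_list:
    - file.txt
    run_order:
    -   executor: ...
        sampler: ...
        runner: ...
storage:
    type: "hdf5" # or "None"
```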

See the config folder for example configuration files.

Module supervisor.supervisor

Supervisor module.

Provides the Supervisor class, which coordinates configuration, execution, sampling, and result aggregation for simulation runs.

Supervisor

Supervisor(args, config_path=None)

Creates a supervisor that handles the configuration, execution, and file output of the program.

Attributes:

    args (Namespace): Namespace containing the configuration parameters
    executor (Executor): Executor for this run
    sampler (Sampler): Sampler for this run
    base_run_dir (str): Path where the runner saves result files

Methods:

    start: Starts the simulation process; the main function of the supervisor.
    create_base_run_dir: Creates the base directory for simulation run results.
    all_processes_done: Returns true when all simulations are done.
    wait_all_processes: Waits in a while loop until all simulations are done.
    create_dataset: Creates a pandas DataFrame that combines all the "enchanted_datapoint.csv" files of the run directories.
    create_hdf5: Creates an HDF5-structured file with the numeric data of enchanted_dataset and its metadata.

Initializes supervisor and sets class attributes.

Parameters:

    args (Namespace): Namespace containing the configuration parameters. Required.
    config_path (str or None): Optional path to the configuration file from which the configuration is fetched. Default: None.
Source code in src/enchanted_surrogates/supervisor/supervisor.py
def __init__(self, args, config_path=None):
    """
    Initializes supervisor and sets class attributes.

    Arguments:
        args (argparse.Namespace): Namespace containing the configuration parameters.
        config_path (str or None): Optional path for configuration file where
            configuration is fetched from.
    """
    self.args = args
    executors = import_executors(args)
    samplers = import_samplers(args)
    group_configs = import_run_groups(args)

    self.groups: list[RunGroup] = []
    for group in group_configs:
        run_group = RunGroup(
            executors[group["executor"]],
            samplers[group["sampler"]],
            args.runners[group["runner"]],
        )
        run_group.executor.runner_config = run_group.runner
        self.groups.append(run_group)

    self.base_run_dir = args.supervisor.get("base_run_dir")
    self.run_mode = args.supervisor.get("run_mode", "fresh")
    self.save_files_arg = args.supervisor.get("save_files", "all")

    if self.base_run_dir is None:
        if sys.stdout.isatty():
            self.base_run_dir = "base_run_dir"
            log.warning(
                "No config for base_run_dir was found, "
                "created base_run_dir folder to working directory"
            )
        else:
            raise ValueError(
                "base_run_dir is not set in the provided configuration"
            )

    self.local_storage = args.supervisor.get("local_storage")

    if self.local_storage and not os.path.exists(self.local_storage):
        env = self.local_storage
        self.local_storage = os.environ.get(env)

        if not self.local_storage:
            log.warning(f"Local storage environment variable {env} not set, ignoring...")

    self.data_dir = os.path.join(self.base_run_dir, "data")
    self.previous_run_file = os.path.join(self.base_run_dir, "enchanted_run.yaml")
    self.previous_run_data = None

    if self.run_mode in ("resume", "extend"):
        self.previous_run_data = RunData.load(self.previous_run_file)
        if self.previous_run_data:
            if len(self.groups) > self.previous_run_data.depth:
                # Extending should generate budget worth of new samples so add
                # already submitted amount to the current budget
                if self.run_mode == "extend":
                    self.groups[
                        self.previous_run_data.depth
                    ].sampler.budget += self.previous_run_data.submitted_samples

                self.groups[self.previous_run_data.depth].sampler.skip(
                    self.previous_run_data.batch_number + 1
                )
        else:
            raise RuntimeError(
                "Tried to continue from previous sampling but no "
                "enchanted_run.yaml was found"
            )

        self.continue_with_base_run_dir(config_path)
    else:
        self.create_base_run_dir(self.base_run_dir, config_path)

all_processes_done

all_processes_done(name_filter=None)

Monitors simulation processes and returns a boolean describing their state. Helper function for wait_all_processes.

Parameters:

    name_filter (str or None): Optional filter that limits checking to run directories containing this text. If None (default), all run directories are checked. Default: None.

Returns:

    True when all simulations are done, i.e. every run folder inside base_run_dir contains an "enchanted_datapoint.csv" file; False if any runner has not yet created its csv file.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
def all_processes_done(self, name_filter=None):
    """
    Monitors simulation processes and returns boolean describing state.
    Helper function for wait_all_processes.

    Args:
        name_filter (str or None): Optional filter used to limit checking to run directories
            containing this text. If None (default), all run directories are checked.
    Returns:
        True when all simulations are done. Helper function for
            wait_all_processes. Checks inside base_run_dir if folders inside it
            contain "enchanted_datapoint.csv" files.
        False If any runner has not yet created the csv file
    """

    for name in os.listdir(self.data_dir):
        if not name_filter or str(name_filter) in str(name):
            folder_path = os.path.join(self.data_dir, name)
            if os.path.isdir(folder_path):
                datapoint_file = os.path.join(
                    folder_path, "enchanted_datapoint.csv"
                )
                if not os.path.isfile(datapoint_file):
                    return False

    return True

batch_dirs_done

batch_dirs_done(run_dirs)

Checks whether an enchanted_datapoint.csv file exists in each of the given directories.

Parameters:

    run_dirs (list[str]): List of run directories within the batch

Returns:

    True if all datapoint files are found; False if any of them is missing.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
def batch_dirs_done(self, run_dirs: list[str]) -> bool:
    """
    Checks whether an enchanted_datapoint.csv file exists in each of the given directories.

    Args:
        run_dirs (list[str]): List of run directories within the batch

    Returns:
        False if any of the datapoint files in the run_dirs is missing
        True if all datapoint files are found
    """
    for d in run_dirs:
        if not os.path.isfile(os.path.join(d, "enchanted_datapoint.csv")):
            return False
    return True
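
The marker-file check can be tried outside the Supervisor. This is a minimal sketch that reimplements the same logic against made-up temporary run directories, not a call into the package itself:

```python
import os
import tempfile

def batch_dirs_done(run_dirs):
    """Return True only if every run directory contains the marker CSV."""
    return all(
        os.path.isfile(os.path.join(d, "enchanted_datapoint.csv"))
        for d in run_dirs
    )

# Build two fake run directories; only the first gets a marker file at first.
tmp = tempfile.mkdtemp()
dirs = [os.path.join(tmp, f"d0_b0_r{i}") for i in range(2)]
for d in dirs:
    os.makedirs(d)
open(os.path.join(dirs[0], "enchanted_datapoint.csv"), "w").close()

print(batch_dirs_done(dirs))  # False: the second runner has not written its marker
open(os.path.join(dirs[1], "enchanted_datapoint.csv"), "w").close()
print(batch_dirs_done(dirs))  # True once every marker exists
```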

continue_with_base_run_dir

continue_with_base_run_dir(config_path)

Deletes old unfinished batches, prompting the user whether to keep them. Creates a base_run_dir if one does not exist.

Parameters:

    config_path (str or None): Optional path to the configuration file from which the configuration is fetched

Source code in src/enchanted_surrogates/supervisor/supervisor.py
def continue_with_base_run_dir(self, config_path):
    """
    Deletes old unfinished batches, prompting the user whether to keep them.
    Creates a base_run_dir if one does not exist

    Args:
        config_path (str or None): Optional path for configuration file where
            configuration is fetched from
    """

    if not os.path.exists(self.base_run_dir):
        self.create_base_run_dir(self.base_run_dir, config_path)
        return

    if not os.path.exists(self.data_dir):
        os.makedirs(self.data_dir, exist_ok=True)

    dirs = glob.glob(f"{self.data_dir}/"
        + f"d{self.previous_run_data.depth}_b{self.previous_run_data.batch_number + 1}*")

    if not dirs:
        return

    for path in dirs:
        shutil.rmtree(path)

create_base_run_dir

create_base_run_dir(base_run_dir, config_path)

Creates the base directory for simulation run results. Checks whether base_run_dir is empty and prompts the user to delete any existing data in it; execution stops if the user chooses not to delete the files. Copies the config file to base_run_dir if one was provided.

Parameters:

    base_run_dir (str): Path where the runner saves result files
    config_path (str or None): Optional path to the configuration file from which the configuration is fetched.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
def create_base_run_dir(self, base_run_dir, config_path):
    """
    Creates base directory for simulation run results. Checks if base_run_dir
    is empty. Prompts user option to delete existing data in base_run_dir.
    Execution is stopped if user chooses to not delete files. Copies config_file
    to base_run_dir if config_file was provided.


    Args:
        base_run_dir (str): Path where runner saves result files
        config_path (str or None): Optional path for configuration file where
            configuration is fetched from.
    """

    # Make sure that there is nothing in base_run_dir
    if os.path.exists(base_run_dir):
        if next(os.scandir(base_run_dir), None):
            if sys.stdout.isatty():
                value = input(
                    str(os.path.abspath(base_run_dir))
                    + "\nFolders have content. "
                    + "Do you want to delete data in existing folders? y/N "
                )
            else:
                print(
                    str(os.path.abspath(base_run_dir))
                    + "\nFolders have content. If you wish to continue, go delete them"
                )
                value = "n"

            if value.lower() in ("y", "yes"):
                shutil.rmtree(base_run_dir)
                print("base_run_dir was deleted")
            else:
                print("No content was deleted. Enchanted surrogates is exited.")
                sys.exit(1)

    # Create base run dir and data dir inside it
    os.makedirs(base_run_dir, exist_ok=True)
    os.makedirs(self.data_dir, exist_ok=True)

    # Move config path to base_run_dir if config path is given
    if config_path is not None:
        config_dir = os.path.join(base_run_dir, "config")
        os.makedirs(config_dir, exist_ok=True)

        new_config_path = os.path.join(config_dir, os.path.basename(config_path))
        print(f"Moving config file... from {config_path} to {new_config_path}")
        try:
            shutil.copy(config_path, new_config_path)
        except OSError as exe:
            warnings.warn(
                "Failed to copy configuration file to base run directory.\n"
                "Try using the full path to the config file. \n"
                f"Source: '{config_path}'\n"
                f"Target: '{new_config_path}'\n"
                f"Error type: {type(exe).__name__}\n"
                f"Error message: {exe}"
            )

create_dataset

create_dataset()

Creates a pandas DataFrame that combines all the "enchanted_datapoint.csv" files of the run directories inside base_run_dir.

Returns:

    pandas.DataFrame containing all the enchanted_datapoint.csv files created by runners.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
def create_dataset(self):
    """
    Creates a pandas DataFrame that combines all the "enchanted_datapoint.csv"
    files of the run directories inside base_run_dir.

    Returns:
        pandas.DataFrame containing all the enchanted_datapoint.csv files
        created by runners.

    """
    enchanted_dataset = pd.DataFrame()

    for name in os.listdir(self.data_dir):
        folder_path = os.path.join(self.data_dir, name)
        if os.path.isdir(folder_path):
            datapoint_file = os.path.join(folder_path, "enchanted_datapoint.csv")
            if os.path.isfile(datapoint_file):
                enchanted_datapoint = pd.read_csv(datapoint_file)
                enchanted_dataset = pd.concat(
                    [enchanted_dataset, enchanted_datapoint]
                )
    return enchanted_dataset
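
The aggregation step can be reproduced standalone. This sketch writes hypothetical per-run datapoint files into a temporary data directory and concatenates them the same way create_dataset does; the directory names and values are made up:

```python
import os
import tempfile

import pandas as pd

data_dir = tempfile.mkdtemp()

# Write two per-run datapoint files, as a runner would.
for i, x in enumerate([1.0, 2.0]):
    run_dir = os.path.join(data_dir, f"d0_b0_r{i}")
    os.makedirs(run_dir)
    pd.DataFrame({"x": [x], "y": [x * 10]}).to_csv(
        os.path.join(run_dir, "enchanted_datapoint.csv"), index=False
    )

# Aggregate every run's CSV into one DataFrame.
frames = [
    pd.read_csv(os.path.join(data_dir, name, "enchanted_datapoint.csv"))
    for name in sorted(os.listdir(data_dir))
]
enchanted_dataset = pd.concat(frames, ignore_index=True)
print(enchanted_dataset)
```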

create_hdf5

create_hdf5(enchanted_dataset)

Creates an HDF5 storage file named runs.h5 in base_run_dir. It includes aggregated data from enchanted_dataset and run-specific data in a structured format. The datasets contain only numeric values; column headers are saved separately in the same location. Metadata includes the types of the sampler, executor, and runner.

Parameters:

    enchanted_dataset (pd.DataFrame): DataFrame containing all run results

Source code in src/enchanted_surrogates/supervisor/supervisor.py
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
def create_hdf5(self, enchanted_dataset: pd.DataFrame):
    """
    Creates an HDF5 storage file named runs.h5 in base_run_dir.
    Includes aggregated data from enchanted_dataset and run-specific data
    in a structured format. The datasets contain only numeric values; column
    headers are saved separately in the same location. Metadata includes the
    types of the sampler, executor and runner.

    Args:
        enchanted_dataset (pd.DataFrame): DataFrame containing all run results

    """
    h5_path = os.path.join(self.base_run_dir, "runs.h5")

    with h5py.File(h5_path, "w") as f:
        # Aggregated dataset
        agg_group = f.create_group("data/aggregated")

        agg_group.create_dataset(
            "values",
            data=enchanted_dataset.select_dtypes(include=[np.number]).to_numpy(),
        )

        agg_group.create_dataset(
            "columns",
            data=np.array(
                enchanted_dataset.select_dtypes(include=[np.number]).columns,
                dtype=h5py.string_dtype(encoding="utf-8"),
            ),
        )

        # Run directory datasets
        runs_group = f.create_group("data/runs")

        for name in os.listdir(self.data_dir):
            folder_path = os.path.join(self.data_dir, name)
            csv_path = os.path.join(folder_path, "enchanted_datapoint.csv")

            if not os.path.isfile(csv_path):
                continue

            df = pd.read_csv(csv_path)
            run_group = runs_group.create_group(name)

            # Select only numeric values
            numeric_df = df.select_dtypes(include=[np.number])

            run_group.create_dataset("values", data=numeric_df.to_numpy())

            run_group.create_dataset(
                "columns",
                data=np.array(numeric_df.columns, dtype=h5py.string_dtype("utf-8")),
            )

        # Metadata
        meta_group = f.create_group("metadata")
        run_groups = meta_group.create_group("run_groups")
        for i, run_group in enumerate(self.groups):
            meta_run_group = run_groups.create_group(str(i))
            meta_run_group.attrs["executor"] = str(
                run_group.executor.__class__.__name__
            )
            meta_run_group.attrs["sampler"] = str(
                run_group.sampler.__class__.__name__
            )
            meta_run_group.attrs["runner"] = str(run_group.runner.get("type"))
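
The values/columns split that feeds the HDF5 layout comes down to select_dtypes. A pandas-only sketch with a made-up frame (no h5py needed to see the idea):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0], "label": ["a", "b"], "y": [3, 4]})

# Only numeric columns go into the "values" dataset ...
numeric = df.select_dtypes(include=[np.number])
values = numeric.to_numpy()

# ... and their headers are stored separately as a string array.
columns = np.array(numeric.columns, dtype=object)
print(values.shape, list(columns))
```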

delete_unwanted_files

delete_unwanted_files(argument, base_dir=None)

Deletes files according to the given save_files argument.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
def delete_unwanted_files(self, argument: str, base_dir: str | None = None):
    """
    Deletes files according to the given save_files argument.
    """
    default_list = ["enchanted_dataset.csv", "runs.h5"]
    if argument == "all":
        return

    if argument == "custom":
        saved_list = import_saved_files_list(self.args)
        allowed_files = set(default_list) | set(saved_list)
    elif argument == "none":
        allowed_files = set(default_list)
    else:
        return

    if base_dir is None:
        base_dir = self.base_run_dir

    for root, dirs, files in os.walk(base_dir, topdown=False):
        # Remove files
        for file in files:
            if file not in allowed_files:
                file_path = os.path.join(root, file)
                os.remove(file_path)
        # Remove dirs
        for dir_ in dirs:
            dir_path = os.path.join(root, dir_)
            if not os.listdir(dir_path):
                os.rmdir(dir_path)
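
The bottom-up cleanup can be demonstrated in isolation. This sketch uses a throwaway directory tree and an assumed allow-list; only the walk-and-prune logic mirrors the method above:

```python
import os
import tempfile

base = tempfile.mkdtemp()
run = os.path.join(base, "d0_b0_r0")
os.makedirs(run)
for name in ("enchanted_dataset.csv", "scratch.log"):
    open(os.path.join(base, name), "w").close()
open(os.path.join(run, "debug.txt"), "w").close()

allowed = {"enchanted_dataset.csv", "runs.h5", "file.txt"}

# Bottom-up walk: delete disallowed files first, then prune empty dirs.
for root, dirs, files in os.walk(base, topdown=False):
    for f in files:
        if f not in allowed:
            os.remove(os.path.join(root, f))
    for d in dirs:
        path = os.path.join(root, d)
        if not os.listdir(path):
            os.rmdir(path)

print(sorted(os.listdir(base)))  # only the allowed summary file remains
```

Walking with topdown=False matters: run directories are visited before their parent, so an emptied run directory can be removed when the parent is reached.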

fetch_from_local_storage

fetch_from_local_storage()

Moves all files from local_storage to base_run_dir, if local_storage is defined.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
577
578
579
580
581
582
583
584
585
def fetch_from_local_storage(self):
    """
    Moves all files from local_storage to base_run_dir, if local_storage is defined.
    """
    if self.local_storage:
        for item in os.listdir(self.local_storage):
            src = os.path.join(self.local_storage, item)
            dst = os.path.join(self.base_run_dir, item)
            shutil.move(src, dst)

load_batch_to_df

load_batch_to_df(run_dirs)

Creates a pd.DataFrame combining the enchanted_datapoint.csv files from the given list of run directories.

Parameters:

    run_dirs (list[str]): List of run directories within the batch

Returns:

    pd.DataFrame containing the combined batch datapoints

Source code in src/enchanted_surrogates/supervisor/supervisor.py
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
def load_batch_to_df(self, run_dirs: list[str]):
    """
    Creates a pd.DataFrame combining the enchanted_datapoint.csv files from the
    given list of run directories.

    Args:
        run_dirs (list[str]): List of run directories within the batch

    Returns:
        pd.DataFrame containing batch datapoints combined
    """
    dfs = []
    for d in run_dirs:
        file = os.path.join(d, "enchanted_datapoint.csv")
        dfs.append(pd.read_csv(file))
    return pd.concat(dfs)

read_summary

read_summary(filename='enchanted_dataset')

Reads the summary written by write_summary.

Parameters:

    filename (str): name of the file to be read

Returns:

    pd.DataFrame: the read dataset

Source code in src/enchanted_surrogates/supervisor/supervisor.py
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def read_summary(self, filename: str = "enchanted_dataset") -> pd.DataFrame:
    """
    Reads the summary written by write_summary.

    Args:
        filename (str): name of the file to be read

    Returns:
        pd.DataFrame: the read dataset
    """
    if (
        self.args.supervisor
        and self.args.supervisor.get("summary_datatype") == "parquet"
    ):
        file = os.path.join(self.base_run_dir, f"{filename}.parquet")
        if os.path.exists(file):
            return pd.read_parquet(
                file,
                engine="pyarrow",
            )
        return pd.DataFrame()

    file = os.path.join(self.base_run_dir, f"{filename}.csv")
    if os.path.exists(file):
        return pd.read_csv(file)

    return pd.DataFrame()

start

start()

Main function of the supervisor. Starts the simulation process and is currently the only function accessed outside of supervisor.py. Gathers samples and run paths, passes them to the executor, and creates the summary file after all processes have finished.

Source code in src/enchanted_surrogates/supervisor/supervisor.py
def start(self):
    """
    Main function of the supervisor. Starts the simulation process and is
    currently the only function accessed outside of supervisor.py.
    Gathers samples and paths, and gives them to the executor. After all
    processes are finished, creates the summary file.
    """

    log.info("Starting runs...")

    if self.local_storage:
        real_run_dir = self.local_storage
    else:
        real_run_dir = self.base_run_dir

    last_complete_dataset = pd.DataFrame()

    for depth, group in enumerate(self.groups):
        batch_number = 0
        batch_dataset = pd.DataFrame()

        # Restore run state from previous data, if needed and in correct position of the loops
        if self.previous_run_data:
            if depth < self.previous_run_data.depth:
                continue

            if depth == self.previous_run_data.depth:
                batch_number = self.previous_run_data.batch_number + 1

                batch_dataset = self.read_summary()
                last_complete_dataset = self.read_summary(
                    "last_complete_enchanted_dataset"
                )
                group.sampler.register_future(last_complete_dataset)

        while group.sampler.has_budget:
            samples = group.sampler.get_next_samples()

            # Merge parameter names for nesting. On first depth run, expanded=samples
            expanded = []
            if not last_complete_dataset.empty:
                for parent in last_complete_dataset.to_dict(orient="records"):
                    for sample in samples:
                        expanded.append({**parent, **sample})
            else:
                expanded = samples

            # Create run directories named by depth, batch and sample numbers
            run_dirs = [
                os.path.join(real_run_dir, "data", f"d{depth}_b{batch_number}_r{i}")
                for i in range(len(expanded))
            ]

            group.executor.execute(zip(run_dirs, expanded), group.sampler)

            # Wait processes of current batch to complete
            self.wait_batch_dirs(run_dirs)

            # Then the files in this batch should be saved into summary files
            df_batch = self.load_batch_to_df(run_dirs)
            batch_dataset = pd.concat([batch_dataset, df_batch])
            group.sampler.register_future(batch_dataset)

            run_data = RunData(
                batch_number=batch_number,
                depth=depth,
                submitted_samples=group.sampler.submitted,
            )
            run_data.save(self.previous_run_file)

            # Create summary csv or parquet file
            self.write_summary(batch_dataset)

            self.fetch_from_local_storage()

            # Clean unwanted files
            self.delete_unwanted_files(self.save_files_arg, real_run_dir)

            batch_number += 1

        # Update data rows for next nesting level
        last_complete_dataset = batch_dataset.copy()

        # Create a summary file with last_complete_dataset for nesting
        if depth < len(self.groups) - 1:
            self.write_summary(
                last_complete_dataset, "last_complete_enchanted_dataset"
            )

        self.fetch_from_local_storage()

    # Create HDF5 file by default
    if not hasattr(self.args, "storage") or self.args.storage.get("type") != "None":
        self.create_hdf5(last_complete_dataset)

    # Clean unwanted files
    self.delete_unwanted_files(self.save_files_arg, real_run_dir)

    # Clean run_dirs
    print("Shutting down scheduler and workers...")
    for group in self.groups:
        group.executor.clean()
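
The nesting expansion inside start (combining each parent row from the previous depth with every new sample) reduces to a dictionary merge. A minimal sketch with made-up parameter names:

```python
# Parent rows from the previous depth and fresh samples from the sampler.
parents = [{"a": 1}, {"a": 2}]
samples = [{"b": 10}, {"b": 20}]

# Every parent is paired with every sample; sample keys override on clash.
expanded = [{**parent, **sample} for parent in parents for sample in samples]
print(expanded)
```

On the first depth there is no parent data, which is why start falls back to expanded = samples.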

wait_all_processes

wait_all_processes(name_filter=None)

Waits in a while loop until all simulations are done. The loop ends when all_processes_done returns true; the condition is checked once per second.

Parameters:

    name_filter (str or None): Optional filter that limits waiting to run directories containing this text. If None (default), all run directories are waited on. Default: None.
Source code in src/enchanted_surrogates/supervisor/supervisor.py
def wait_all_processes(self, name_filter=None):
    """
    Waits in a while loop until all simulations are done. The loop ends
    when all_processes_done returns true; the condition is checked once
    per second.

    Args:
        name_filter (str or None): Optional filter used to limit waiting to run directories
            containing this text. If None (default), all run directories are waited on.
    """

    while True:
        if self.all_processes_done(name_filter):
            break
        sleep(1)

wait_batch_dirs

wait_batch_dirs(run_dirs)

Waits until batch_dirs_done returns True.

Parameters:

    run_dirs (list[str]): List of run directories within the batch

Source code in src/enchanted_surrogates/supervisor/supervisor.py
453
454
455
456
457
458
459
460
461
def wait_batch_dirs(self, run_dirs: list[str]):
    """
    Waits until batch_dirs_done returns True.

    Args:
        run_dirs (list[str]): List of run directories within the batch
    """
    while not self.batch_dirs_done(run_dirs):
        sleep(1)

write_summary

write_summary(dataset, filename='enchanted_dataset')

Writes a summary of the dataset to base_run_dir/filename. This functionality is used within the start function to enable seamless sampling.

Parameters:

    dataset (pd.DataFrame): dataset to be written
    filename (str): filename for the written file

Source code in src/enchanted_surrogates/supervisor/supervisor.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def write_summary(self, dataset: pd.DataFrame, filename: str = "enchanted_dataset"):
    """
    Writes a summary of the dataset to base_run_dir/filename.
    This functionality is used within the start function to
    enable seamless sampling.

    Args:
        dataset (pd.DataFrame): dataset to be written
        filename (str): filename for the written file
    """
    if (
        self.args.supervisor
        and self.args.supervisor.get("summary_datatype") == "parquet"
    ):
        dataset.to_parquet(
            os.path.join(self.base_run_dir, f"{filename}.parquet"),
            engine="pyarrow",
            index=False,
        )
    else:
        dataset.to_csv(
            os.path.join(self.base_run_dir, f"{filename}.csv"), index=False
        )
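
The default csv round trip of write_summary and read_summary can be sketched without the Supervisor class. The directory here is temporary, not real run output:

```python
import os
import tempfile

import pandas as pd

base_run_dir = tempfile.mkdtemp()
dataset = pd.DataFrame({"x": [1.0, 2.0], "y": [10.0, 20.0]})

# write_summary equivalent for the default csv datatype.
path = os.path.join(base_run_dir, "enchanted_dataset.csv")
dataset.to_csv(path, index=False)

# read_summary equivalent: an empty frame if the file is missing.
loaded = pd.read_csv(path) if os.path.exists(path) else pd.DataFrame()
print(loaded.equals(dataset))
```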