Source code for mmgp.utils

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
import os
import shutil

import yaml
import logging
# Logger shared by every helper of this module.
logger = logging.getLogger(__name__)

# Root-logger configuration: INFO level, with each record prefixed by its time,
# level, file, function and line number. Per the stdlib contract, basicConfig
# does nothing if the root logger already has handlers.
# NOTE(review): configuring the root logger at import time also affects any
# application importing this module — consider moving this to an entry point.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s:%(levelname)s:%(filename)s:%(funcName)s(%(lineno)d)]:%(message)s',
)

from plaid.containers.dataset import Dataset
from plaid.problem_definition import ProblemDefinition

#: This dictionary defines the expected structure of a YAML configuration file.
#: It specifies the keys that should be present in the configuration file and
#: their expected data types: either a type (or tuple of types) the value must
#: be an instance of, or a nested dict describing a sub-section.
expected_structure = {
    'case_name': str,
    'init_dataset_location': str,
    'generated_data_folder': str,
    'zone_name': str,
    'base_name': str,
    'train_set': str,
    'test_set': str,
    'common_mesh_index': int,
    'verbose': bool,
    'morphing': {
        'algo': str,
        'options': str,
    },
    'dimensionality_reduction': {
        'input_coord_fields': {
            'algo': str,
            'options': {
                'number_of_modes': int,
                'correlation_type': str,
            },
        },
        'output_coord_fields': {
            'algo': str,
            'options': {
                'number_of_modes': int,
                'correlation_type': str,
            },
        },
    },
    'regression': {
        'reference_regressor': int,
        'uncertainties': bool,
        'number_Monte_Carlo_samples': int,
        'algo': str,
        'options': {
            'kernel': str,
            'kernel_options': {
                'nu': float,
            },
            'optim': str,
            'num_restarts': int,
            'anisotropic': bool,
            # "int or None": ``None`` itself is not a type and would make
            # ``isinstance`` raise, so ``type(None)`` is used instead.
            'random_state': (int, type(None)),
        },
    },
}
def load_backend(backend_module: str):
    """Import ``backend_module`` and return the ``Regressor`` it defines.

    Args:
        backend_module (str): Module name passed to
            :func:`importlib.import_module`.

    Returns:
        The object bound to the name ``Regressor`` in the imported module.

    Raises:
        ImportError: If the module cannot be imported. The underlying import
            error is chained as ``__cause__`` so the real reason is not lost.
        AttributeError: If the imported module has no ``Regressor`` attribute.
    """
    import importlib

    try:
        module = importlib.import_module(backend_module)
    except ImportError as err:
        # Keep the original failure (missing dependency, syntax error in the
        # backend, ...) visible in the traceback.
        raise ImportError(f"Failed to load {backend_module}") from err
    return module.Regressor
def validate_configs(configs: dict, v_structure: dict, depth_accumulator: str = "") -> tuple[bool, str]:
    """Recursively validate the structure and data types of a configuration dictionary.

    Every key of ``v_structure`` must be present in ``configs``. When the expected
    value is a type, the configuration value must be an instance of it; when it is
    a dict, the corresponding configuration value must itself be a dict and is
    validated recursively. Keys named ``"options"`` are only checked for presence;
    their content is not validated yet. Keys present in ``configs`` but absent from
    ``v_structure`` are reported with a warning and otherwise ignored.

    Args:
        configs (dict): The configuration dictionary to be validated.
        v_structure (dict): The expected structure.
        depth_accumulator (str): The path from the root of the configuration to
            the level being checked (e.g. ``"regression -> "``). It prefixes
            every error message.

    Returns:
        tuple[bool, str]: ``(True, "")`` if ``configs`` matches the expected
        structure and data types; otherwise ``(False, reason)``, where ``reason``
        describes the first mismatch found, prefixed by its full path.
    """
    # A non-mapping (e.g. an empty YAML document loaded as None, or a scalar
    # given where a whole section was expected) cannot be checked key by key.
    if not isinstance(configs, dict):
        return False, depth_accumulator + f"expected a mapping, got {type(configs).__name__} '{configs}'"

    for key, value in v_structure.items():
        if key not in configs:
            return False, depth_accumulator + f"missing {key = }"

        # TODO: the content of "options" sections is not validated yet.
        if key == "options":
            continue

        if isinstance(value, type):
            if not isinstance(configs[key], value):
                return False, depth_accumulator + f"bad type '{configs[key]}' for {key = }. Expected type was '{value}'"
        elif isinstance(value, dict):
            test_passed, reason = validate_configs(configs[key], value, depth_accumulator + f"{key} -> ")
            if not test_passed:
                # The nested call already prefixed its reason with the full
                # path; prefixing again would duplicate it.
                return False, reason
        else:
            return False, depth_accumulator + f"unknown type '{value}' for {key = }"

    # Everything required is present; report keys this structure does not know.
    extra_keys = set(configs) - set(v_structure)
    if extra_keys:
        logger.warning(f"Too much information provided, these elements will not be taken into account: {depth_accumulator}{extra_keys}")
    return True, ""
def read_configuration(config_file: str) -> dict:
    """Read, complete and validate a YAML configuration file.

    Args:
        config_file (str): The path to the YAML configuration file.

    Returns:
        dict: The configuration read from the file. The optional keys
        ``verbose`` (default ``True``), ``train_set`` (default ``"train"``) and
        ``test_set`` (default ``"test"``) are filled in when absent. ``None`` is
        returned, and a warning logged, when ``config_file`` does not exist.

    Raises:
        TypeError: If the file does not contain a YAML mapping, or if its content
            does not match ``expected_structure``. In the latter case the message
            gives the path of the first mismatch.

    Example:
        .. code-block:: python

            configuration = read_configuration('config.yaml')

    Attention:
        - Ensure that the configuration file exists at the specified path.
        - The function assumes that the configuration file is in YAML format.
    """
    if not os.path.isfile(config_file):  # pragma: no cover
        logger.warning(
            "file " + config_file + " not found, unread configuration")
        return None

    # safe_load only builds plain Python objects; never use yaml.load here.
    with open(config_file, 'r', encoding='utf-8') as file:
        configuration = yaml.safe_load(file)

    # An empty file loads as None and a top-level list/scalar is not a config.
    if not isinstance(configuration, dict):
        raise TypeError(
            f"configuration file {config_file} must contain a mapping, "
            f"got {type(configuration).__name__}")

    # Defaults must be applied BEFORE validation: expected_structure lists these
    # keys as required, so filling them in afterwards would never take effect.
    configuration.setdefault("verbose", True)
    configuration.setdefault("train_set", "train")
    configuration.setdefault("test_set", "test")

    test_passed, reason = validate_configs(configuration, expected_structure)
    if not test_passed:
        raise TypeError(reason)

    return configuration
def read_problem(configuration: dict) -> ProblemDefinition:
    """Build the problem definition of the case described by ``configuration``.

    The first sample (id 0) of the dataset stored under
    ``<init_dataset_location>/dataset`` is loaded to collect the available field
    names (for the configured zone and base) and scalar names. The problem
    definition is then loaded from ``<init_dataset_location>/problem_definition``.
    Its input scalar, output scalar and output field names are replaced by the
    result of the corresponding ``filter_*_names`` method applied to those names.
    Presumably this keeps only the names present in the data; confirm against
    ``ProblemDefinition``.

    Args:
        configuration (dict): Configuration providing the keys
            ``"init_dataset_location"``, ``"zone_name"`` and ``"base_name"``.

    Returns:
        ProblemDefinition: The loaded and filtered problem definition.

    Attention:
        Ensure that the initial dataset and problem definition files exist at the
        specified path.
    """
    first_sample = Dataset()
    location = configuration['init_dataset_location']

    # A single sample is enough to discover which names exist.
    first_sample._load_from_dir_(location + os.sep + "dataset", ids=[0], verbose=False)
    available_fields = first_sample.get_field_names(
        zone_name=configuration['zone_name'],
        base_name=configuration['base_name'],
    )
    available_scalars = first_sample.get_scalar_names()

    definition = ProblemDefinition()
    definition._load_from_dir_(location + os.sep + "problem_definition")

    definition.in_scalars_names = definition.filter_input_scalars_names(available_scalars)
    definition.out_scalars_names = definition.filter_output_scalars_names(available_scalars)
    definition.out_fields_names = definition.filter_output_fields_names(available_fields)
    return definition
def reset_folder(path: str) -> None:  # pragma: no cover
    """Make sure ``path`` exists and is an empty directory.

    If ``path`` is already a directory, it is removed together with everything
    inside it. The directory, and any missing parent directories, is then
    created with :func:`os.makedirs`.

    Args:
        path (str): The folder to empty or create.

    Attention:
        Use this function with caution: whatever the folder previously contained
        is permanently deleted.
    """
    folder_exists = os.path.isdir(path)
    if folder_exists:
        shutil.rmtree(path)
    os.makedirs(path)
def remove_file(path: str) -> None:  # pragma: no cover
    """Delete ``path`` if it is an existing regular file; otherwise do nothing.

    Missing paths and directories are silently ignored.

    Args:
        path (str): The path to the file to be removed.

    Attention:
        Use this function with caution: the file is permanently deleted.
    """
    if not os.path.isfile(path):
        return
    os.remove(path)