# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
import os
import shutil
import yaml
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs when this module is imported and configures
# the *root* logger for the whole process (format + INFO level); it is a no-op
# if the root logger already has handlers. Consider moving this call to the
# application's entry point so library users keep control of logging.
logging.basicConfig(
format='[%(asctime)s:%(levelname)s:%(filename)s:%(funcName)s(%(lineno)d)]:%(message)s',
level=logging.INFO)
from plaid.containers.dataset import Dataset
from plaid.problem_definition import ProblemDefinition
[docs]
expected_structure = dict(
    # Case identification and data locations.
    case_name=str,
    init_dataset_location=str,
    generated_data_folder=str,
    # Names used to select fields in the dataset.
    zone_name=str,
    base_name=str,
    # Split names and general switches.
    train_set=str,
    test_set=str,
    common_mesh_index=int,
    verbose=bool,
    morphing=dict(
        algo=str,
        options=str,
    ),
    # Input and output coordinate-field reductions share the same layout;
    # the comprehension builds a separate dict for each of them.
    dimensionality_reduction={
        fields_kind: dict(
            algo=str,
            options=dict(
                number_of_modes=int,
                correlation_type=str,
            ),
        )
        for fields_kind in ("input_coord_fields", "output_coord_fields")
    },
    regression=dict(
        reference_regressor=int,
        uncertainties=bool,
        number_Monte_Carlo_samples=int,
        algo=str,
        options=dict(
            kernel=str,
            kernel_options=dict(
                nu=float,
            ),
            optim=str,
            num_restarts=int,
            anisotropic=bool,
            # NOTE(review): `None` is not a type, so this tuple cannot be
            # handed to isinstance() as-is. It is never type-checked today
            # because "options" sub-trees are skipped during validation;
            # confirm the intended meaning (e.g. a (type, default) pair).
            random_state=(int, None),
        ),
    ),
)
"""Expected layout of the YAML configuration file.

Each key maps to the Python type its value must have; nested dictionaries
describe nested sections of the configuration file.
"""
[docs]
def load_backend(backend_module: str):
    """Import a backend module and return its ``Regressor`` attribute.

    Args:
        backend_module (str): Dotted module path, passed to
            ``importlib.import_module``.

    Returns:
        The object bound to the name ``Regressor`` in the imported module.

    Raises:
        ImportError: If the module cannot be imported. This includes failures
            raised by the module's own imports. The original error is chained
            as ``__cause__`` so the real reason is not lost.
        AttributeError: If the module imports fine but defines no ``Regressor``.
    """
    import importlib

    try:
        module = importlib.import_module(backend_module)
    except ImportError as err:
        # Chain explicitly: the failure may come from a dependency of the
        # backend rather than from the backend path itself.
        raise ImportError(f"Failed to load {backend_module}") from err
    return module.Regressor
[docs]
def validate_configs(configs: dict, v_structure: dict, depth_accumulator: str = ""):
    """Recursively validate the structure and data types of a configuration dictionary.

    Checks that every key of ``v_structure`` is present in ``configs`` and that
    its value has the expected type. Nested dictionaries in ``v_structure`` are
    validated recursively against the matching sub-dictionaries of ``configs``.
    Keys named ``"options"`` are only checked for presence, not content.
    Keys present in ``configs`` but absent from ``v_structure`` are reported
    with a warning and otherwise ignored.

    Args:
        configs (dict): The configuration dictionary to be validated.
        v_structure (dict): The expected structure. Each value is either a type
            or a nested dict of the same form.
        depth_accumulator (str): Path from the root of the configuration to
            ``configs`` (e.g. ``"regression -> "``). It is used to prefix
            error messages.

    Returns:
        tuple[bool, str]: ``(True, "")`` when ``configs`` matches the expected
        structure. Otherwise ``(False, reason)``, where ``reason`` starts with
        the full path to the offending entry.
    """
    # A section that is not a mapping cannot be checked key by key. A str would
    # make `key in configs` a substring test, and an int would raise.
    if not isinstance(configs, dict):
        return False, depth_accumulator + f"expected a mapping, got '{configs}'"
    for key, value in v_structure.items():
        # Missing key
        if key not in configs:
            # TODO: for optional parts only
            #if isinstance(value, tuple):
            #    continue
            return False, depth_accumulator + f"missing {key = }"
        # TODO: temporary -> skip options
        if key == "options":
            continue
        # Comparing type
        if isinstance(value, type):
            if not isinstance(configs[key], value):
                return False, depth_accumulator + f"bad type '{configs[key]}' for {key = }. Expected type was '{value}'"
        # Recursively checks dicts
        elif isinstance(value, dict):
            test_passed, reason = validate_configs(configs[key], value, depth_accumulator + f"{key} -> ")
            if not test_passed:
                # `reason` already starts with the full path built by the
                # recursive call; prefixing it again would duplicate it.
                return False, reason
        else:
            return False, depth_accumulator + f"unknown type '{value}' for {key = }"
    # Every expected key is present at this point, so any extra key is surplus.
    disjoint_elements = set(configs) - set(v_structure)
    if disjoint_elements:
        logger.warning(
            "Too much information provided, these elements will not be taken into account: %s%s",
            depth_accumulator, disjoint_elements)
    return True, ""
[docs]
def read_configuration(config_file: str) -> "dict | None":
    """Read, complete and validate a YAML configuration file.

    Optional entries missing from the file receive default values:
    ``verbose=True``, ``train_set="train"`` and ``test_set="test"``. The result
    is then checked against ``expected_structure``.

    Args:
        config_file (str): The path to the YAML configuration file.

    Returns:
        dict: The configuration mapping, or ``None`` when ``config_file`` does
        not exist (a warning is logged in that case).

    Raises:
        TypeError: If the file does not contain a top-level mapping, or if the
            configuration does not match ``expected_structure``.

    Example:
        .. code-block:: python

            configuration = read_configuration('config.yaml')

    Attention:
        - Ensure that the configuration file exists at the specified path.
        - The file is parsed as YAML with ``yaml.safe_load``.
    """
    if not os.path.isfile(config_file):  # pragma: no cover
        logger.warning(
            "file " +
            config_file +
            " not found, unread configuration")
        return None

    with open(config_file, 'r', encoding='utf-8') as file:
        configuration = yaml.safe_load(file)

    # An empty file yields None, and a scalar/list document is not usable either.
    if not isinstance(configuration, dict):
        raise TypeError(
            f"configuration file '{config_file}' must contain a YAML mapping, "
            f"got {type(configuration).__name__}")

    # Default values are filled in *before* validation. The validator requires
    # every key of expected_structure, so defaults applied afterwards could
    # never take effect.
    configuration.setdefault("verbose", True)
    configuration.setdefault("train_set", "train")
    configuration.setdefault("test_set", "test")

    test_passed, reason = validate_configs(configuration, expected_structure)
    if not test_passed:
        raise TypeError(reason)
    return configuration
[docs]
def read_problem(configuration: dict) -> ProblemDefinition:
    """Load the problem definition and align its names with the dataset.

    Args:
        configuration (dict): Configuration mapping. Its
            ``init_dataset_location``, ``zone_name`` and ``base_name`` entries
            are used.

    Returns:
        ProblemDefinition: The problem loaded from
        ``<init_dataset_location>/problem_definition``. Its input scalar,
        output scalar and output field names are replaced by the result of the
        matching ``filter_*_names`` call on the names found in the dataset.

    Attention:
        Ensure that ``<init_dataset_location>/dataset`` and
        ``<init_dataset_location>/problem_definition`` exist.
    """
    root = configuration['init_dataset_location']

    # Only sample 0 is loaded. Here it serves solely to list the available
    # field and scalar names (presumably identical across samples - confirm).
    dataset = Dataset()
    dataset._load_from_dir_(
        os.path.join(root, "dataset"),
        ids=[0],
        verbose=False)
    field_names = dataset.get_field_names(
        zone_name=configuration['zone_name'],
        base_name=configuration['base_name'])
    scalar_names = dataset.get_scalar_names()

    problem = ProblemDefinition()
    problem._load_from_dir_(os.path.join(root, "problem_definition"))
    # NOTE(review): filter_*_names presumably keep the problem's declared names
    # that occur in the given list - confirm against ProblemDefinition.
    # Input field names are not filtered here; confirm this is intentional.
    problem.in_scalars_names = problem.filter_input_scalars_names(scalar_names)
    problem.out_scalars_names = problem.filter_output_scalars_names(scalar_names)
    problem.out_fields_names = problem.filter_output_fields_names(field_names)
    return problem
[docs]
def reset_folder(path: str) -> None: # pragma: no cover
    """Make ``path`` an empty directory.

    An existing directory at ``path`` is removed together with everything it
    contains. The directory, including any missing parents, is then created
    afresh.

    Args:
        path (str): Directory to empty or create.

    Attention:
        Destructive: everything stored under ``path`` is deleted permanently.
    """
    must_wipe = os.path.isdir(path)
    if must_wipe:
        shutil.rmtree(path)
    os.makedirs(path)
[docs]
def remove_file(path: str) -> None: # pragma: no cover
    """Delete the file at ``path`` if there is one.

    Nothing happens when ``path`` does not exist or is not a regular file
    (for example, a directory).

    Args:
        path (str): Path of the file to delete.

    Attention:
        The file is deleted permanently.
    """
    if not os.path.isfile(path):
        return
    os.remove(path)
[docs]
def print_setting(configuration: dict, problem: ProblemDefinition) -> None:
    """Pretty-print the configuration and the problem definition to stdout.

    The text is assembled with Muscat's ``TFormat`` helpers (colors and
    indentation) and written with ``print``. Nothing is returned.

    Args:
        configuration (dict): Configuration mapping, e.g. as returned by
            ``read_configuration``. Its section is skipped when ``None``.
        problem (ProblemDefinition): Problem definition, e.g. as returned by
            ``read_problem``. Its section is skipped when ``None``.
    """
    from Muscat.Helpers.TextFormatHelper import TFormat

    def print_dict(str_repr, d, depth=0):
        """Append an aligned ``key : value`` listing of ``d`` to ``str_repr``."""
        # Pad keys to the longest one so the separators line up. default=0
        # keeps an empty (sub-)section from raising ValueError in max().
        maxl = max((len(str(k)) for k in d), default=0)
        if depth > 0:
            # TFormat.II()/DI() are called three times each, presumably shifting
            # GetIndent() by three levels for nested sections.
            for _ in range(3):
                TFormat.II()
        for cat, info in d.items():
            # YAML keys are not necessarily strings (e.g. integers); only str
            # has ljust().
            label = str(cat)
            if isinstance(info, dict):
                str_repr += TFormat.GetIndent() + TFormat.InBlue(label.ljust(maxl)) + " :\n"
                str_repr = print_dict(str_repr, info, depth + 1)
            else:
                str_repr += TFormat.GetIndent() + TFormat.InBlue(label.ljust(maxl + 1)) + \
                    ": " + str(info) + "\n"
        if depth > 0:
            for _ in range(3):
                TFormat.DI()
        return str_repr

    TFormat.Reset()
    str_repr = ""
    if configuration is not None:
        str_repr += TFormat.InGreen(TFormat.Center("Configuration")) + "\n"
        str_repr = print_dict(str_repr, configuration)
    if problem is not None:
        str_repr += TFormat.InGreen(TFormat.Center("Problem definition")) + "\n"
        str_repr += TFormat.GetIndent() + TFormat.InBlue("split names".ljust(20)) + \
            " : " + str(list(problem.get_split().keys())) + "\n"
        str_repr += TFormat.GetIndent() + TFormat.InBlue("input scalar names".ljust(20)) + \
            " : " + str(problem.in_scalars_names) + "\n"
        str_repr += TFormat.GetIndent() + TFormat.InBlue("output scalar names".ljust(20)) + \
            " : " + str(problem.out_scalars_names) + "\n"
        str_repr += TFormat.GetIndent() + TFormat.InBlue("output field names".ljust(20)) + \
            " : " + str(problem.out_fields_names) + "\n"
    print(str_repr)