diff --git a/src/fastoad/io/configuration/configuration.py b/src/fastoad/io/configuration/configuration.py index 6ee29a396..7854120ab 100644 --- a/src/fastoad/io/configuration/configuration.py +++ b/src/fastoad/io/configuration/configuration.py @@ -21,17 +21,15 @@ from importlib.resources import open_text from typing import Dict -import numpy as np import openmdao.api as om import tomlkit from jsonschema import validate from ruamel.yaml import YAML from fastoad._utils.files import make_parent_dir -from fastoad.io import DataFile, IVariableIOFormatter +from fastoad.io import IVariableIOFormatter from fastoad.module_management.service_registry import RegisterOpenMDAOSystem, RegisterSubmodel from fastoad.openmdao.problem import FASTOADProblem -from fastoad.openmdao.variables import VariableList from . import resources from .exceptions import ( FASTConfigurationBadOpenMDAOInstructionError, @@ -115,6 +113,10 @@ def get_problem(self, read_inputs: bool = False, auto_scaling: bool = False) -> problem = FASTOADProblem() self._build_model(problem) + + if self._configuration_modifier: + self._configuration_modifier.modify(problem) + problem.input_file_path = self.input_file_path problem.output_file_path = self.output_file_path @@ -132,9 +134,6 @@ def get_problem(self, read_inputs: bool = False, auto_scaling: bool = False) -> if read_inputs: self._add_design_vars(problem.model, auto_scaling) - if self._configuration_modifier: - self._configuration_modifier.modify(problem) - return problem def load(self, conf_file): @@ -213,33 +212,7 @@ def write_needed_inputs( not provided, expected format will be the default one. """ problem = self.get_problem(read_inputs=False) - problem.setup() - variables = DataFile(self.input_file_path, load_data=False) - - unconnected_inputs = VariableList.from_problem( - problem, - use_initial_values=True, - get_promoted_names=True, - promoted_only=True, - io_status="inputs", - ) - - variables.update( - unconnected_inputs, - add_variables=True, - ) - if source_file_path: - ref_vars = DataFile(source_file_path, formatter=source_formatter) - variables.update(ref_vars, add_variables=False) - nan_variable_names = [] - for var in variables: - var.is_input = True - # Checking if variables have NaN values - if np.any(np.isnan(var.value)): - nan_variable_names.append(var.name) - if nan_variable_names: - _LOGGER.warning("The following variables have NaN values: %s", nan_variable_names) - variables.save() + problem.write_needed_inputs(source_file_path, source_formatter) def get_optimization_definition(self) -> Dict: """ diff --git a/src/fastoad/openmdao/problem.py b/src/fastoad/openmdao/problem.py index 23f62040a..fdd363a73 100644 --- a/src/fastoad/openmdao/problem.py +++ b/src/fastoad/openmdao/problem.py @@ -11,6 +11,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import logging from dataclasses import dataclass, field from typing import Optional, Tuple @@ -19,7 +20,7 @@ from openmdao.core.constants import _SetupStatus from openmdao.core.system import System -from fastoad.io import DataFile, VariableIO +from fastoad.io import DataFile, IVariableIOFormatter, VariableIO from fastoad.module_management.service_registry import RegisterSubmodel from fastoad.openmdao.validity_checker import ValidityDomainChecker from fastoad.openmdao.variables import Variable, VariableList @@ -27,6 +28,8 @@ from .exceptions import FASTOpenMDAONanInInputFile from ..module_management._bundle_loader import BundleLoader +_LOGGER = logging.getLogger(__name__) # Logger for this module + # Name of IVC that will contain input values INPUT_SYSTEM_NAME = "fastoad_inputs" @@ -92,6 +95,46 @@ def setup(self, *args, **kwargs): self._read_inputs_with_setup_done() BundleLoader().clean_memory() + def write_needed_inputs( + self, source_file_path: str = None, source_formatter: IVariableIOFormatter = None + ): + """ + Writes the input file of the problem using its unconnected inputs. + + Written value of each variable will be taken: + + 1. from input_data if it contains the variable + 2. from defined default values in component definitions + + :param source_file_path: if provided, variable values will be read from it + :param source_formatter: the class that defines format of input file. if + not provided, expected format will be the default one. + """ + self.self_analysis() + + variables = DataFile(self.input_file_path, load_data=False) + + unconnected_inputs = VariableList( + [variable for variable in self._analysis.problem_variables if variable.is_input] + ) + + variables.update( + unconnected_inputs, + add_variables=True, + ) + if source_file_path: + ref_vars = DataFile(source_file_path, formatter=source_formatter) + variables.update(ref_vars, add_variables=False) + nan_variable_names = [] + for var in variables: + var.is_input = True + # Checking if variables have NaN values + if np.any(np.isnan(var.value)): + nan_variable_names.append(var.name) + if nan_variable_names: + _LOGGER.warning("The following variables have NaN values: %s", nan_variable_names) + variables.save() + def write_outputs(self): """ Writes all outputs in the configured output file. diff --git a/tests/integration_tests/oad_process/test_oad_process.py b/tests/integration_tests/oad_process/test_oad_process.py index 2188fbf57..450875ad2 100644 --- a/tests/integration_tests/oad_process/test_oad_process.py +++ b/tests/integration_tests/oad_process/test_oad_process.py @@ -2,7 +2,7 @@ Test module for Overall Aircraft Design process """ # This file is part of FAST-OAD : A framework for rapid Overall Aircraft Design -# Copyright (C) 2023 ONERA & ISAE-SUPAERO +# Copyright (C) 2024 ONERA & ISAE-SUPAERO # FAST is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or @@ -50,12 +50,11 @@ def test_oad_process(cleanup): configurator = FASTOADProblemConfigurator(pth.join(DATA_FOLDER_PATH, "oad_process.yml")) - # Create inputs ref_inputs = pth.join(DATA_FOLDER_PATH, "CeRAS01_legacy.xml") - configurator.write_needed_inputs(ref_inputs) + problem = configurator.get_problem() + problem.write_needed_inputs(ref_inputs) + problem.read_inputs() - # Create problems with inputs - problem = configurator.get_problem(read_inputs=True) problem.setup() problem.run_model() problem.write_outputs()