Skip to content

Commit

Permalink
Refactoring.
Browse files Browse the repository at this point in the history
write_needed_inputs() back in FASTOADProblem

When not using high-level API, it will allow to have problem analysis done once when doing write_needed_inputs() and not in subsequent operations.
  • Loading branch information
christophe-david committed Jan 15, 2024
1 parent 2b8c019 commit 9f9c86c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 39 deletions.
39 changes: 6 additions & 33 deletions src/fastoad/io/configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@
from importlib.resources import open_text
from typing import Dict

import numpy as np
import openmdao.api as om
import tomlkit
from jsonschema import validate
from ruamel.yaml import YAML

from fastoad._utils.files import make_parent_dir
from fastoad.io import DataFile, IVariableIOFormatter
from fastoad.io import IVariableIOFormatter
from fastoad.module_management.service_registry import RegisterOpenMDAOSystem, RegisterSubmodel
from fastoad.openmdao.problem import FASTOADProblem
from fastoad.openmdao.variables import VariableList
from . import resources
from .exceptions import (
FASTConfigurationBadOpenMDAOInstructionError,
Expand Down Expand Up @@ -115,6 +113,10 @@ def get_problem(self, read_inputs: bool = False, auto_scaling: bool = False) ->

problem = FASTOADProblem()
self._build_model(problem)

if self._configuration_modifier:
self._configuration_modifier.modify(problem)

problem.input_file_path = self.input_file_path
problem.output_file_path = self.output_file_path

Expand All @@ -132,9 +134,6 @@ def get_problem(self, read_inputs: bool = False, auto_scaling: bool = False) ->
if read_inputs:
self._add_design_vars(problem.model, auto_scaling)

if self._configuration_modifier:
self._configuration_modifier.modify(problem)

return problem

def load(self, conf_file):
Expand Down Expand Up @@ -213,33 +212,7 @@ def write_needed_inputs(
not provided, expected format will be the default one.
"""
problem = self.get_problem(read_inputs=False)
problem.setup()
variables = DataFile(self.input_file_path, load_data=False)

unconnected_inputs = VariableList.from_problem(
problem,
use_initial_values=True,
get_promoted_names=True,
promoted_only=True,
io_status="inputs",
)

variables.update(
unconnected_inputs,
add_variables=True,
)
if source_file_path:
ref_vars = DataFile(source_file_path, formatter=source_formatter)
variables.update(ref_vars, add_variables=False)
nan_variable_names = []
for var in variables:
var.is_input = True
# Checking if variables have NaN values
if np.any(np.isnan(var.value)):
nan_variable_names.append(var.name)
if nan_variable_names:
_LOGGER.warning("The following variables have NaN values: %s", nan_variable_names)
variables.save()
problem.write_needed_inputs(source_file_path, source_formatter)

def get_optimization_definition(self) -> Dict:
"""
Expand Down
45 changes: 44 additions & 1 deletion src/fastoad/openmdao/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging
from dataclasses import dataclass, field
from typing import Optional, Tuple

Expand All @@ -19,14 +20,16 @@
from openmdao.core.constants import _SetupStatus
from openmdao.core.system import System

from fastoad.io import DataFile, VariableIO
from fastoad.io import DataFile, IVariableIOFormatter, VariableIO
from fastoad.module_management.service_registry import RegisterSubmodel
from fastoad.openmdao.validity_checker import ValidityDomainChecker
from fastoad.openmdao.variables import Variable, VariableList
from ._utils import get_problem_copy_without_mpi
from .exceptions import FASTOpenMDAONanInInputFile
from ..module_management._bundle_loader import BundleLoader

_LOGGER = logging.getLogger(__name__) # Logger for this module

# Name of IVC that will contain input values
INPUT_SYSTEM_NAME = "fastoad_inputs"

Expand Down Expand Up @@ -92,6 +95,46 @@ def setup(self, *args, **kwargs):
self._read_inputs_with_setup_done()
BundleLoader().clean_memory()

def write_needed_inputs(
self, source_file_path: str = None, source_formatter: IVariableIOFormatter = None
):
"""
Writes the input file of the problem using its unconnected inputs.
Written value of each variable will be taken:
1. from input_data if it contains the variable
2. from defined default values in component definitions
:param source_file_path: if provided, variable values will be read from it
:param source_formatter: the class that defines format of input file. if
not provided, expected format will be the default one.
"""
self.self_analysis()

variables = DataFile(self.input_file_path, load_data=False)

unconnected_inputs = VariableList(
[variable for variable in self._analysis.problem_variables if variable.is_input]
)

variables.update(
unconnected_inputs,
add_variables=True,
)
if source_file_path:
ref_vars = DataFile(source_file_path, formatter=source_formatter)
variables.update(ref_vars, add_variables=False)
nan_variable_names = []
for var in variables:
var.is_input = True
# Checking if variables have NaN values
if np.any(np.isnan(var.value)):
nan_variable_names.append(var.name)
if nan_variable_names:
_LOGGER.warning("The following variables have NaN values: %s", nan_variable_names)
variables.save()

def write_outputs(self):
"""
Writes all outputs in the configured output file.
Expand Down
9 changes: 4 additions & 5 deletions tests/integration_tests/oad_process/test_oad_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Test module for Overall Aircraft Design process
"""
# This file is part of FAST-OAD : A framework for rapid Overall Aircraft Design
# Copyright (C) 2023 ONERA & ISAE-SUPAERO
# Copyright (C) 2024 ONERA & ISAE-SUPAERO
# FAST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
Expand Down Expand Up @@ -50,12 +50,11 @@ def test_oad_process(cleanup):

configurator = FASTOADProblemConfigurator(pth.join(DATA_FOLDER_PATH, "oad_process.yml"))

# Create inputs
ref_inputs = pth.join(DATA_FOLDER_PATH, "CeRAS01_legacy.xml")
configurator.write_needed_inputs(ref_inputs)
problem = configurator.get_problem()
problem.write_needed_inputs(ref_inputs)
problem.read_inputs()

# Create problems with inputs
problem = configurator.get_problem(read_inputs=True)
problem.setup()
problem.run_model()
problem.write_outputs()
Expand Down

0 comments on commit 9f9c86c

Please sign in to comment.