Skip to content

Commit

Permalink
Simplify domain traversal logic
Browse files Browse the repository at this point in the history
This commit simplifies how directories are identified and traversed when
`domain.init()` is called. Most importantly, we no longer load a module
if it has been loaded already.
  • Loading branch information
subhashb committed Jun 18, 2024
1 parent 5a1ba80 commit eb2aba2
Showing 1 changed file with 45 additions and 57 deletions.
102 changes: 45 additions & 57 deletions src/protean/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,19 @@ def _traverse(self):
import os
import pathlib

# Directory containing the domain file
root_dir = pathlib.PurePath(pathlib.Path(self.root_path).resolve()).parent
path = pathlib.Path(root_dir) # Resolve the domain file's directory
system_folder_path = (
path.parent
) # Get the directory of the domain file to traverse from

# Parent Directory of the directory containing the domain file
#
# We need this to decipher paths from the root. For example,
# say the domain file is in a directory called `test13`, and
# we are traversing a subdirectory `auth` inside `test13`.
# We need to resolve the module for files in the `auth` directory
# as `test13.auth`.
#
# This makes relative imports possible
system_folder_path = pathlib.Path(root_dir).parent

logger.debug(f"Loading domain from {root_dir}...")

Expand All @@ -233,8 +241,9 @@ def _traverse(self):
and name not in ["__pycache__"]
]

directories_to_traverse = [str(root_dir)] # Include root directory

# Identify subdirectories that have a toml file
subdirectories_to_traverse = []
files_to_check = ["domain.toml", ".domain.toml", "pyproject.toml"]
for subdirectory in subdirectories:
subdirectory_path = os.path.join(root_dir, subdirectory)
Expand All @@ -243,59 +252,38 @@ def _traverse(self):
for file in files_to_check
if os.path.isfile(os.path.join(subdirectory_path, file))
):
subdirectories_to_traverse.append(subdirectory_path)

# Traverse root directory
for filename in os.listdir(root_dir):
full_file_path = os.path.join(root_dir, filename)
if (
os.path.isfile(full_file_path)
and os.path.splitext(filename)[1] == ".py"
and full_file_path != self.root_path
):
spec = importlib.util.spec_from_file_location(
filename, os.path.join(root_dir, filename)
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

logger.debug(f"Loaded {filename}")

# Traverse subdirectories recursively
for directory in subdirectories_to_traverse:
for root, _, files in os.walk(directory):
if pathlib.PurePath(root).name not in ["__pycache__"]:
package_path = root[len(str(system_folder_path)) + 1 :]
module_name = package_path.replace(os.sep, ".")

for file in files:
file_base_name = os.path.basename(file)

# Ignore if the file is not a python file
if os.path.splitext(file_base_name)[1] != ".py":
continue

# Construct the module path to import from
if file_base_name != "__init__":
sub_module_name = os.path.splitext(file_base_name)[0]
file_module_name = module_name + "." + sub_module_name
else:
file_module_name = module_name
full_file_path = os.path.join(root, file)

try:
if (
full_file_path != self.root_path
): # Don't load the domain file itself again
spec = importlib.util.spec_from_file_location(
file_module_name, full_file_path
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
directories_to_traverse.append(subdirectory_path)

# Traverse directories one by one
for directory in directories_to_traverse:
for filename in os.listdir(directory):
package_path = directory[len(str(system_folder_path)) + 1 :]
module_name = package_path.replace(os.sep, ".")
full_file_path = os.path.join(directory, filename)

if (
os.path.isfile(full_file_path)
and os.path.splitext(filename)[1] == ".py"
and full_file_path != self.root_path
):
# Construct the module path to import from
if filename != "__init__.py":
sub_module_name = os.path.splitext(filename)[0]
file_module_name = module_name + "." + sub_module_name
else:
file_module_name = module_name
full_file_path = os.path.join(root_dir, filename)

spec = importlib.util.spec_from_file_location(
file_module_name, os.path.join(directory, filename)
)
module = importlib.util.module_from_spec(spec)

# Do not load module again if it has already been loaded
if module.__name__ not in sys.modules:
spec.loader.exec_module(module)

logger.debug(f"Loaded {file_module_name}")
except ModuleNotFoundError as exc:
logger.error(f"Error while loading a module: {exc}")
logger.debug(f"Loaded {filename}")

def _initialize(self):
"""Initialize domain dependencies and adapters."""
Expand Down

0 comments on commit eb2aba2

Please sign in to comment.