Skip to content

Commit

Permalink
Merge pull request #198 from visualDust/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
visualDust authored Oct 16, 2024
2 parents 4b1c795 + 54b1d5b commit 737abc0
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 50 deletions.
22 changes: 22 additions & 0 deletions neetbox/config/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,31 @@
_GLOBAL_CONFIG = {
MACHINE_ID_KEY: str(uuid4()),
"vault": get_create_neetbox_data_directory(),
"bypass-db-version-check": False,
}

_GLOBAL_CONFIG_FILE_NAME = f"neetbox.global.toml"


def update_dict_recursively_on_missing_keys(A, B):
"""
Update dictionary B with keys from dictionary A. Add missing keys from A to B,
but do not overwrite existing keys in B. Handles nested dictionaries recursively.
"""
missed_keys = []
for key, value in A.items():
if key not in B:
missed_keys.append(key)
B[key] = value
else:
if isinstance(value, dict) and isinstance(B[key], dict):
missed_keys += update_dict_recursively_on_missing_keys(value, B[key])
else:
# Do not modify B[key] if it already exists
pass
return missed_keys


def overwrite_create_local(config: dict):
neetbox_config_dir = get_create_neetbox_config_directory()
config_file_path = os.path.join(neetbox_config_dir, _GLOBAL_CONFIG_FILE_NAME)
Expand All @@ -40,7 +60,9 @@ def read_create_local():
# read local file
user_cfg = check_read_toml(config_file_path)
assert user_cfg
update_dict_recursively_on_missing_keys(_GLOBAL_CONFIG, user_cfg)
_GLOBAL_CONFIG.update(user_cfg)
overwrite_create_local(_GLOBAL_CONFIG)


def set(key, value):
Expand Down
17 changes: 13 additions & 4 deletions neetbox/server/db/project/_project_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ def __new__(cls, project_id: str = None, path: str = None, **kwargs) -> "Project
)
_db_file_version = new_dbc.fetch_db_version(NEETBOX_VERSION)
if NEETBOX_VERSION != _db_file_version:
logger.warn(
f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}"
)
if get_global_config("bypass-db-version-check"):
logger.warn(
f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}"
)
else:
raise RuntimeError(
f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}. If you want to bypass this check, set 'bypass-db-version-check' to True in global config. This may cause unexpected behavior."
)
cls._path2dbc[path] = new_dbc
manager.current[project_id] = new_dbc
new_dbc.project_id = project_id
Expand Down Expand Up @@ -425,7 +430,11 @@ def read_blob(
def load_db_of_path(cls, path):
if not os.path.isfile(path):
raise RuntimeError(f"{path} is not a file")
conn = ProjectDB(path=path)
try:
conn = ProjectDB(path=path)
except Exception as e:
logger.err(f"failed to load db from {path} cause {e}")
return None
return conn

@classmethod
Expand Down
110 changes: 65 additions & 45 deletions neetbox/utils/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

class _RegEndpoint:
def __init__(self, what, tags=None):
"""Generate a massive type which contains both the regietered object and it's tags
"""Generate a massive type which contains both the registered object and its tags
Args:
what (_type_): The object being registered
tags (_type_, optional): The tags. Defaults to None.
what (Any): The object being registered
tags (Optional[Union[str, Sequence[str]]], optional): The tags. Defaults to None.
"""
self.what = what
self.tags = tags
Expand All @@ -26,8 +26,8 @@ def __str__(self) -> str:


def _tags_match(search_tags, in_tags) -> bool:
# check if all tags in f_tags are listed in s_tags
if type(search_tags) is not list:
"""Check if all tags in search_tags are listed in in_tags."""
if not isinstance(search_tags, list):
search_tags = [search_tags]
for _t in search_tags:
if _t not in in_tags:
Expand All @@ -36,14 +36,13 @@ def _tags_match(search_tags, in_tags) -> bool:


class Registry(dict):

"""Register Helper Class
A Register is a 'dict[str:any]'
Registers are stored in a pool of type dict[str:Register]
A Registry is a 'dict[str:any]'
Registries are stored in a pool of type dict[str:Registry]
"""

# class level
_registry_pool: Dict[str, "Registry"] = dict() # all registeres are stored here
# Class-level registry pool
_registry_pool: Dict[str, "Registry"] = dict()

def __new__(cls, name: str) -> "Registry":
if name in cls._registry_pool:
Expand All @@ -61,23 +60,20 @@ def _register(
self,
what: Any,
name: Optional[str] = None,
overwrite: Union[bool, Callable] = lambda x: x + f"_{uuid4()}",
overwrite: Union[bool, Callable[[str], str]] = lambda x: x + f"_{uuid4()}",
tags: Optional[Union[str, Sequence[str]]] = None,
):
# if not (inspect.isfunction(what) or inspect.isclass(what)):
# logger.warn(f"Registering {type(what)}, which is not a class or a callable.")
name = name or what.__name__
if type(tags) is str:
if isinstance(tags, str):
tags = [tags]
_endp = _RegEndpoint(what, tags)
if name in self.keys():
if isinstance(overwrite, Callable):
if callable(overwrite):
name = overwrite(name)
elif overwrite == True:
elif overwrite is True:
pass
else:
raise RuntimeError(f"Unknown overwrite type.")

raise RuntimeError("Unknown overwrite type.")
self[name] = _endp
else:
self[name] = _endp
Expand All @@ -87,10 +83,13 @@ def register(
self,
*,
name: Optional[str] = None,
overwrite: Union[bool, Callable] = lambda x: x + f"_{uuid4()}",
overwrite: Union[bool, Callable[[str], str]] = lambda x: x + f"_{uuid4()}",
tags: Optional[Union[str, Sequence[str]]] = None,
):
return functools.partial(self._register, name=name, overwrite=overwrite, tags=tags)
def decorator(what):
self._register(what, name=name, overwrite=overwrite, tags=tags)
return what
return decorator

@classmethod
def find(
Expand All @@ -99,51 +98,51 @@ def find(
tags: Optional[Union[str, List[str]]] = None,
):
if not name and not tags:
# logger.err(
# ValueError("Please provide at least the name or the tags you want to find."),
# reraise=True,
# )
pass
raise ValueError("Please provide at least the name or the tags you want to find.")
results = []
# filter name
# Filter by name
for reg_name, reg in cls._registry_pool.items():
private_sign = "__"
if not reg_name.startswith(private_sign):
if not name:
results += [(_n, _o) for _n, _o in reg.items(_real_type=False)]
elif name in reg:
results.append((name, reg[name]))

# filter tags
if type(tags) is not list:
# Filter by tags
if not isinstance(tags, list):
tags = [tags]

results = {_name: _endp.what for _name, _endp in results if _tags_match(tags, _endp.tags)}
results = {
_name: _endp.what
for _name, _endp in results
if _tags_match(tags, _endp.tags)
}
return results

def filter(self, tags: Optional[Union[str, Sequence[str]]] = None):
results = {
_name: _endp.what for _name, _endp in self._items() if _tags_match(tags, _endp.tags)
_name: _endp.what
for _name, _endp in self._items()
if _tags_match(tags, _endp.tags)
}
return results

def __getitem__(self, __key: str) -> Any:
_v = self.__dict__[__key]
if type(_v) is _RegEndpoint:
if isinstance(_v, _RegEndpoint):
_v = _v.what
return _v

def get(self, key: str, **kwargs):
if key in self.__dict__:
_v = self.__dict__[key]
if type(_v) is _RegEndpoint:
if isinstance(_v, _RegEndpoint):
_v = _v.what
return _v
else:
if "default" in kwargs:
return kwargs["default"]
else:
raise RuntimeError(f"key {key} not found")
raise RuntimeError(f"Key '{key}' not found")

def __setitem__(self, k, v) -> None:
self.__dict__[k] = v
Expand All @@ -161,29 +160,50 @@ def update(self, *args, **kwargs):
return self.__dict__.update(*args, **kwargs)

def keys(self):
return [_item[0] for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint]
return [
_item[0]
for _item in self.__dict__.items()
if isinstance(_item[1], _RegEndpoint)
]

def values(self):
return [_item[1].what for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint]
return [
_item[1].what
for _item in self.__dict__.items()
if isinstance(_item[1], _RegEndpoint)
]

def items(self, _real_type=True):
_legal_items = [_item for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint]
_legal_items = [
_item
for _item in self.__dict__.items()
if isinstance(_item[1], _RegEndpoint)
]
if _real_type:
_legal_items = [(_k, _v.what) for _k, _v in _legal_items if type(_v) is _RegEndpoint]
_legal_items = [
(_k, _v.what)
for _k, _v in _legal_items
if isinstance(_v, _RegEndpoint)
]
return _legal_items

def _items(self, _real_type=True):
_legal_items = [_item for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint]
_legal_items = [
_item
for _item in self.__dict__.items()
if isinstance(_item[1], _RegEndpoint)
]
if _real_type:
_legal_items = [(_k, _v) for _k, _v in _legal_items if type(_v) is _RegEndpoint]
_legal_items = [
(_k, _v)
for _k, _v in _legal_items
if isinstance(_v, _RegEndpoint)
]
return _legal_items

def pop(self, *args):
return self.__dict__.pop(*args)

def __cmp__(self, dict_):
return self.__cmp__(self.__dict__, dict_)

def __contains__(self, item):
return item in self.__dict__

Expand All @@ -200,4 +220,4 @@ def __str__(self) -> str:
default=str,
)

__repr__ = __str__
__repr__ = __str__
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "neetbox"
version = "0.4.12"
version = "0.4.13"
description = "Logging/Debugging/Tracing/Managing/Facilitating long running python projects, especially a replacement of tensorboard for deep learning projects"
license = "MIT"
authors = ["VisualDust <[email protected]>", "Lideming <[email protected]>"]
Expand Down

0 comments on commit 737abc0

Please sign in to comment.