diff --git a/neetbox/config/user/__init__.py b/neetbox/config/user/__init__.py index b719e1eb..78e7585e 100644 --- a/neetbox/config/user/__init__.py +++ b/neetbox/config/user/__init__.py @@ -19,11 +19,31 @@ _GLOBAL_CONFIG = { MACHINE_ID_KEY: str(uuid4()), "vault": get_create_neetbox_data_directory(), + "bypass-db-version-check": False, } _GLOBAL_CONFIG_FILE_NAME = f"neetbox.global.toml" +def update_dict_recursively_on_missing_keys(A, B): + """ + Update dictionary B with keys from dictionary A. Add missing keys from A to B, + but do not overwrite existing keys in B. Handles nested dictionaries recursively. + """ + missed_keys = [] + for key, value in A.items(): + if key not in B: + missed_keys.append(key) + B[key] = value + else: + if isinstance(value, dict) and isinstance(B[key], dict): + missed_keys += update_dict_recursively_on_missing_keys(value, B[key]) + else: + # Do not modify B[key] if it already exists + pass + return missed_keys + + def overwrite_create_local(config: dict): neetbox_config_dir = get_create_neetbox_config_directory() config_file_path = os.path.join(neetbox_config_dir, _GLOBAL_CONFIG_FILE_NAME) @@ -40,7 +60,9 @@ def read_create_local(): # read local file user_cfg = check_read_toml(config_file_path) assert user_cfg + update_dict_recursively_on_missing_keys(_GLOBAL_CONFIG, user_cfg) _GLOBAL_CONFIG.update(user_cfg) + overwrite_create_local(_GLOBAL_CONFIG) def set(key, value): diff --git a/neetbox/server/db/project/_project_db.py b/neetbox/server/db/project/_project_db.py index f05fff11..86dd0a41 100644 --- a/neetbox/server/db/project/_project_db.py +++ b/neetbox/server/db/project/_project_db.py @@ -62,9 +62,14 @@ def __new__(cls, project_id: str = None, path: str = None, **kwargs) -> "Project ) _db_file_version = new_dbc.fetch_db_version(NEETBOX_VERSION) if NEETBOX_VERSION != _db_file_version: - logger.warn( - f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}" - ) + if get_global_config("bypass-db-version-check"): + logger.warn( + f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}" + ) + else: + raise RuntimeError( + f"History file version not match: reading from version {_db_file_version} with neetbox version {NEETBOX_VERSION}. If you want to bypass this check, set 'bypass-db-version-check' to True in global config. This may cause unexpected behavior." + ) cls._path2dbc[path] = new_dbc manager.current[project_id] = new_dbc new_dbc.project_id = project_id @@ -425,7 +430,11 @@ def read_blob( def load_db_of_path(cls, path): if not os.path.isfile(path): raise RuntimeError(f"{path} is not a file") - conn = ProjectDB(path=path) + try: + conn = ProjectDB(path=path) + except Exception as e: + logger.err(f"failed to load db from {path} cause {e}") + return None return conn @classmethod diff --git a/neetbox/utils/_registry.py b/neetbox/utils/_registry.py index f48195ce..8de7bc26 100644 --- a/neetbox/utils/_registry.py +++ b/neetbox/utils/_registry.py @@ -12,11 +12,11 @@ class _RegEndpoint: def __init__(self, what, tags=None): - """Generate a massive type which contains both the regietered object and it's tags + """Generate a massive type which contains both the registered object and its tags Args: - what (_type_): The object being registered - tags (_type_, optional): The tags. Defaults to None. + what (Any): The object being registered + tags (Optional[Union[str, Sequence[str]]], optional): The tags. Defaults to None. """ self.what = what self.tags = tags @@ -26,8 +26,8 @@ def __str__(self) -> str: def _tags_match(search_tags, in_tags) -> bool: - # check if all tags in f_tags are listed in s_tags - if type(search_tags) is not list: + """Check if all tags in search_tags are listed in in_tags.""" + if not isinstance(search_tags, list): search_tags = [search_tags] for _t in search_tags: if _t not in in_tags: @@ -36,14 +36,13 @@ def _tags_match(search_tags, in_tags) -> bool: class Registry(dict): - """Register Helper Class - A Register is a 'dict[str:any]' - Registers are stored in a pool of type dict[str:Register] + A Registry is a 'dict[str:any]' + Registries are stored in a pool of type dict[str:Registry] """ - # class level - _registry_pool: Dict[str, "Registry"] = dict() # all registeres are stored here + # Class-level registry pool + _registry_pool: Dict[str, "Registry"] = dict() def __new__(cls, name: str) -> "Registry": if name in cls._registry_pool: @@ -61,23 +60,20 @@ def _register( self, what: Any, name: Optional[str] = None, - overwrite: Union[bool, Callable] = lambda x: x + f"_{uuid4()}", + overwrite: Union[bool, Callable[[str], str]] = lambda x: x + f"_{uuid4()}", tags: Optional[Union[str, Sequence[str]]] = None, ): - # if not (inspect.isfunction(what) or inspect.isclass(what)): - # logger.warn(f"Registering {type(what)}, which is not a class or a callable.") name = name or what.__name__ - if type(tags) is str: + if isinstance(tags, str): tags = [tags] _endp = _RegEndpoint(what, tags) if name in self.keys(): - if isinstance(overwrite, Callable): + if callable(overwrite): name = overwrite(name) - elif overwrite == True: + elif overwrite is True: pass else: - raise RuntimeError(f"Unknown overwrite type.") - + raise RuntimeError("Unknown overwrite type.") self[name] = _endp else: self[name] = _endp @@ -87,10 +83,13 @@ def register( self, *, name: Optional[str] = None, - overwrite: Union[bool, Callable] = lambda x: x + f"_{uuid4()}", + overwrite: Union[bool, Callable[[str], str]] = lambda x: x + f"_{uuid4()}", tags: Optional[Union[str, Sequence[str]]] = None, ): - return functools.partial(self._register, name=name, overwrite=overwrite, tags=tags) + def decorator(what): + self._register(what, name=name, overwrite=overwrite, tags=tags) + return what + return decorator @classmethod def find( @@ -99,13 +98,9 @@ def find( tags: Optional[Union[str, List[str]]] = None, ): if not name and not tags: - # logger.err( - # ValueError("Please provide at least the name or the tags you want to find."), - # reraise=True, - # ) - pass + raise ValueError("Please provide at least the name or the tags you want to find.") results = [] - # filter name + # Filter by name for reg_name, reg in cls._registry_pool.items(): private_sign = "__" if not reg_name.startswith(private_sign): @@ -113,37 +108,41 @@ def find( results += [(_n, _o) for _n, _o in reg.items(_real_type=False)] elif name in reg: results.append((name, reg[name])) - - # filter tags - if type(tags) is not list: + # Filter by tags + if not isinstance(tags, list): tags = [tags] - - results = {_name: _endp.what for _name, _endp in results if _tags_match(tags, _endp.tags)} + results = { + _name: _endp.what + for _name, _endp in results + if _tags_match(tags, _endp.tags) + } return results def filter(self, tags: Optional[Union[str, Sequence[str]]] = None): results = { - _name: _endp.what for _name, _endp in self._items() if _tags_match(tags, _endp.tags) + _name: _endp.what + for _name, _endp in self._items() + if _tags_match(tags, _endp.tags) } return results def __getitem__(self, __key: str) -> Any: _v = self.__dict__[__key] - if type(_v) is _RegEndpoint: + if isinstance(_v, _RegEndpoint): _v = _v.what return _v def get(self, key: str, **kwargs): if key in self.__dict__: _v = self.__dict__[key] - if type(_v) is _RegEndpoint: + if isinstance(_v, _RegEndpoint): _v = _v.what return _v else: if "default" in kwargs: return kwargs["default"] else: - raise RuntimeError(f"key {key} not found") + raise RuntimeError(f"Key '{key}' not found") def __setitem__(self, k, v) -> None: self.__dict__[k] = v @@ -161,29 +160,50 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [_item[0] for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint] + return [ + _item[0] + for _item in self.__dict__.items() + if isinstance(_item[1], _RegEndpoint) + ] def values(self): - return [_item[1].what for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint] + return [ + _item[1].what + for _item in self.__dict__.items() + if isinstance(_item[1], _RegEndpoint) + ] def items(self, _real_type=True): - _legal_items = [_item for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint] + _legal_items = [ + _item + for _item in self.__dict__.items() + if isinstance(_item[1], _RegEndpoint) + ] if _real_type: - _legal_items = [(_k, _v.what) for _k, _v in _legal_items if type(_v) is _RegEndpoint] + _legal_items = [ + (_k, _v.what) + for _k, _v in _legal_items + if isinstance(_v, _RegEndpoint) + ] return _legal_items def _items(self, _real_type=True): - _legal_items = [_item for _item in self.__dict__.items() if type(_item[1]) is _RegEndpoint] + _legal_items = [ + _item + for _item in self.__dict__.items() + if isinstance(_item[1], _RegEndpoint) + ] if _real_type: - _legal_items = [(_k, _v) for _k, _v in _legal_items if type(_v) is _RegEndpoint] + _legal_items = [ + (_k, _v) + for _k, _v in _legal_items + if isinstance(_v, _RegEndpoint) + ] return _legal_items def pop(self, *args): return self.__dict__.pop(*args) - def __cmp__(self, dict_): - return self.__cmp__(self.__dict__, dict_) - def __contains__(self, item): return item in self.__dict__ @@ -200,4 +220,4 @@ def __str__(self) -> str: default=str, ) - __repr__ = __str__ + __repr__ = __str__ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index dc9c7217..28f23d9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "neetbox" -version = "0.4.12" +version = "0.4.13" description = "Logging/Debugging/Tracing/Managing/Facilitating long running python projects, especially a replacement of tensorboard for deep learning projects" license = "MIT" authors = ["VisualDust ", "Lideming "]