diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 5d57dcaef..54b8e36b3 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -78,6 +78,8 @@ def __init__(self): """ self._alias = {} self._connected_alias = {} + self._connection_references = {} + self._con_lock = threading.RLock() self._env_uri = None if Config.MILVUS_URI != "": @@ -199,8 +201,13 @@ def disconnect(self, alias: str): if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - if alias in self._connected_alias: - self._connected_alias.pop(alias).close() + with self._con_lock: + if alias in self._connected_alias: + gh = self._connected_alias.pop(alias) + self._connection_references[id(gh)] -= 1 + if self._connection_references[id(gh)] <= 0: + gh.close() + del self._connection_references[id(gh)] def remove_connection(self, alias: str): """ Removes connection from the registry. @@ -263,17 +270,34 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", **kwargs >>> connections.connect("test", host="localhost", port="19530") """ def connect_milvus(**kwargs): - gh = GrpcHandler(**kwargs) - - t = kwargs.get("timeout") - timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT - - gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') - kwargs.pop('secure', None) - - self._connected_alias[alias] = gh - self._alias[alias] = copy.deepcopy(kwargs) + with self._con_lock: + gh = None + for key, connection_details in self._alias.items(): + + if ( + connection_details["address"] == kwargs["address"] + and connection_details["user"] == kwargs["user"] + and key in self._connected_alias + ): + gh = self._connected_alias[key] + break + + if gh is None: + gh = GrpcHandler(**kwargs) + t = kwargs.get("timeout") + timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT + gh._wait_for_channel_ready(timeout=timeout) + + kwargs.pop('password', None) + kwargs.pop('secure', None) + + self._connected_alias[alias] = gh + self._alias[alias] = copy.deepcopy(kwargs) + + if id(gh) not in self._connection_references: + self._connection_references[id(gh)] = 1 + else: + self._connection_references[id(gh)] += 1 def with_config(config: Tuple) -> bool: for c in config: @@ -293,11 +317,8 @@ def with_config(config: Tuple) -> bool: ) # Make sure passed in None doesnt break - user = user or "" - password = password or "" - # Make sure passed in are Strings - user = str(user) - password = str(password) + user = '' if user is None else str(user) + password = '' if password is None else str(password) # 1st Priority: connection from params if with_config(config): @@ -313,36 +334,32 @@ def with_config(config: Tuple) -> bool: user = parsed_uri.username if parsed_uri.username is not None else user password = parsed_uri.password if parsed_uri.password is not None else password - # Set secure=True if username and password are provided - if len(user) > 0 and len(password) > 0: - kwargs["secure"] = True - - connect_milvus(**kwargs, user=user, password=password) - return # 2nd Priority, connection configs from env - if self._env_uri is not None: + elif self._env_uri is not None: addr, parsed_uri = self._env_uri kwargs["address"] = addr user = parsed_uri.username if parsed_uri.username is not None else "" password = parsed_uri.password if parsed_uri.password is not None else "" - # Set secure=True if uri provided user and password - if len(user) > 0 and len(password) > 0: - kwargs["secure"] = True - connect_milvus(**kwargs, user=user, password=password) - return # 3rd Priority, connect to cached configs with provided user and password - if alias in self._alias: - connect_alias = dict(self._alias[alias].items()) - connect_alias["user"] = user - connect_milvus(**connect_alias, password=password, **kwargs) - return + elif alias in self._alias: + kwargs = dict(self._alias[alias].items()) + # If user is passed in, use it, if not, use previous connections user. + prev_user = kwargs.pop("user") + user = user if user != "" else prev_user # No params, env, and cached configs for the alias - raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + else: + raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + + # Set secure=True if username and password are provided + if len(user) > 0 and len(password) > 0: + kwargs["secure"] = True + + connect_milvus(**kwargs, user=user, password=password) def list_connections(self) -> list: @@ -357,7 +374,8 @@ def list_connections(self) -> list: >>> connections.list_connections() // TODO [('default', None), ('test', )] """ - return [(k, self._connected_alias.get(k, None)) for k in self._alias] + with self._con_lock: + return [(k, self._connected_alias.get(k, None)) for k in self._alias] def get_connection_addr(self, alias: str): """ @@ -402,7 +420,8 @@ def has_connection(self, alias: str) -> bool: """ if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - return alias in self._connected_alias + with self._con_lock: + return alias in self._connected_alias def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: """ Retrieves a GrpcHandler by alias. """