Skip to content

Commit

Permalink
ruff format with line length 120
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivier committed Oct 26, 2024
1 parent 45861d7 commit 4575051
Show file tree
Hide file tree
Showing 81 changed files with 10,066 additions and 3,042 deletions.
100 changes: 79 additions & 21 deletions asyncua/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ def server_url(self) -> ParseResult:
return url

@staticmethod
def find_endpoint(endpoints: Iterable[ua.EndpointDescription], security_mode: ua.MessageSecurityMode, policy_uri: str) -> ua.EndpointDescription:
def find_endpoint(
endpoints: Iterable[ua.EndpointDescription], security_mode: ua.MessageSecurityMode, policy_uri: str
) -> ua.EndpointDescription:
"""
Find endpoint with required security mode and policy URI
"""
_logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
for ep in endpoints:
if ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and ep.SecurityPolicyUri == policy_uri:
if (
ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME)
and ep.SecurityMode == security_mode
and ep.SecurityPolicyUri == policy_uri
):
return ep
raise ua.UaError(f"No matching endpoints: {security_mode}, {policy_uri}")

Expand Down Expand Up @@ -178,7 +184,9 @@ async def set_security_string(self, string: str) -> None:

policy_class = getattr(security_policies, f"SecurityPolicy{parts[0]}")
mode = getattr(ua.MessageSecurityMode, parts[1])
return await self.set_security(policy_class, parts[2], parts[3], client_key_password, parts[4] if len(parts) >= 5 else None, mode)
return await self.set_security(
policy_class, parts[2], parts[3], client_key_password, parts[4] if len(parts) >= 5 else None, mode
)

async def set_security(
self,
Expand Down Expand Up @@ -242,7 +250,9 @@ async def load_client_certificate(self, path: str, extension: Optional[str] = No
"""
self.user_certificate = await uacrypto.load_certificate(path, extension)

async def load_private_key(self, path: Path, password: Optional[Union[str, bytes]] = None, extension: Optional[str] = None) -> None:
async def load_private_key(
self, path: Path, password: Optional[Union[str, bytes]] = None, extension: Optional[str] = None
) -> None:
"""
Load user private key. This is used for authenticating using certificate
"""
Expand Down Expand Up @@ -309,7 +319,9 @@ async def connect(self) -> None:
try:
await self.create_session()
try:
await self.activate_session(username=self._username, password=self._password, certificate=self.user_certificate)
await self.activate_session(
username=self._username, password=self._password, certificate=self.user_certificate
)
except Exception:
# clean up session
await self.close_session()
Expand Down Expand Up @@ -395,7 +407,11 @@ async def open_secure_channel(self, renew: bool = False) -> None:
params.ClientNonce = create_nonce(self.security_policy.secure_channel_nonce_length)
result = await self.uaclient.open_secure_channel(params)
if self.secure_channel_timeout != result.SecurityToken.RevisedLifetime:
_logger.info("Requested secure channel timeout to be %dms, got %dms instead", self.secure_channel_timeout, result.SecurityToken.RevisedLifetime)
_logger.info(
"Requested secure channel timeout to be %dms, got %dms instead",
self.secure_channel_timeout,
result.SecurityToken.RevisedLifetime,
)
self.secure_channel_timeout = result.SecurityToken.RevisedLifetime

async def close_secure_channel(self):
Expand All @@ -408,7 +424,9 @@ async def get_endpoints(self) -> List[ua.EndpointDescription]:
params.EndpointUrl = self.server_url.geturl()
return await self.uaclient.get_endpoints(params)

async def register_server(self, server: "asyncua.server.Server", discovery_configuration: Optional[ua.DiscoveryConfiguration] = None) -> None:
async def register_server(
self, server: "asyncua.server.Server", discovery_configuration: Optional[ua.DiscoveryConfiguration] = None
) -> None:
"""
register a server to discovery server
if discovery_configuration is provided, the newer register_server2 service call is used
Expand All @@ -427,7 +445,9 @@ async def register_server(self, server: "asyncua.server.Server", discovery_confi
return await self.uaclient.register_server2(params)
return await self.uaclient.register_server(serv)

async def unregister_server(self, server: "asyncua.server.Server", discovery_configuration: Optional[ua.DiscoveryConfiguration] = None) -> None:
async def unregister_server(
self, server: "asyncua.server.Server", discovery_configuration: Optional[ua.DiscoveryConfiguration] = None
) -> None:
"""
register a server to discovery server
if discovery_configuration is provided, the newer register_server2 service call is used
Expand Down Expand Up @@ -520,7 +540,11 @@ async def create_session(self) -> ua.CreateSessionResult:
self._policy_ids = ep.UserIdentityTokens
# Actual maximum number of milliseconds that a Session shall remain open without activity
if self.session_timeout != response.RevisedSessionTimeout:
_logger.warning("Requested session timeout to be %dms, got %dms instead", self.secure_channel_timeout, response.RevisedSessionTimeout)
_logger.warning(
"Requested session timeout to be %dms, got %dms instead",
self.secure_channel_timeout,
response.RevisedSessionTimeout,
)
self.session_timeout = response.RevisedSessionTimeout
self._renew_channel_task = asyncio.create_task(self._renew_channel_loop())
self._monitor_server_task = asyncio.create_task(self._monitor_server_loop())
Expand Down Expand Up @@ -615,7 +639,12 @@ def server_policy_uri(self, token_type: ua.UserTokenType) -> str:
return self.security_policy.URI
return self.security_policy.URI

async def activate_session(self, username: Optional[str] = None, password: Optional[str] = None, certificate: Optional[x509.Certificate] = None) -> ua.ActivateSessionResult:
async def activate_session(
self,
username: Optional[str] = None,
password: Optional[str] = None,
certificate: Optional[x509.Certificate] = None,
) -> ua.ActivateSessionResult:
"""
Activate session using either username and password or private_key
"""
Expand Down Expand Up @@ -652,12 +681,16 @@ def _add_certificate_auth(self, params, certificate, challenge):
params.UserTokenSignature = ua.SignatureData()
# use signature algorithm that was used for certificate generation
if certificate.signature_hash_algorithm.name == "sha256":
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, "certificate_basic256sha256")
params.UserIdentityToken.PolicyId = self.server_policy_id(
ua.UserTokenType.Certificate, "certificate_basic256sha256"
)
sig = uacrypto.sign_sha256(self.user_private_key, challenge)
params.UserTokenSignature.Algorithm = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"
params.UserTokenSignature.Signature = sig
else:
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.Certificate, "certificate_basic256")
params.UserIdentityToken.PolicyId = self.server_policy_id(
ua.UserTokenType.Certificate, "certificate_basic256"
)
sig = uacrypto.sign_sha1(self.user_private_key, challenge)
params.UserTokenSignature.Algorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
params.UserTokenSignature.Signature = sig
Expand Down Expand Up @@ -728,7 +761,12 @@ def get_node(self, nodeid: Union[Node, ua.NodeId, str, int]) -> Node:
"""
return Node(self.uaclient, nodeid)

async def create_subscription(self, period: Union[ua.CreateSubscriptionParameters, float], handler: SubscriptionHandler, publishing: bool = True) -> Subscription:
async def create_subscription(
self,
period: Union[ua.CreateSubscriptionParameters, float],
handler: SubscriptionHandler,
publishing: bool = True,
) -> Subscription:
"""
Create a subscription.
Returns a Subscription object which allows to subscribe to events or data changes on server.
Expand Down Expand Up @@ -760,14 +798,24 @@ def get_subscription_revised_params(
params: ua.CreateSubscriptionParameters,
results: ua.CreateSubscriptionResult,
) -> Optional[ua.ModifySubscriptionParameters]:
if results.RevisedPublishingInterval == params.RequestedPublishingInterval and results.RevisedLifetimeCount == params.RequestedLifetimeCount and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount:
if (
results.RevisedPublishingInterval == params.RequestedPublishingInterval
and results.RevisedLifetimeCount == params.RequestedLifetimeCount
and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount
):
return None
_logger.warning("Revised values returned differ from subscription values: %s", results)
revised_interval = results.RevisedPublishingInterval
# Adjust the MaxKeepAliveCount based on the RevisedPublishInterval when necessary
new_keepalive_count = self.get_keepalive_count(revised_interval)
if revised_interval != params.RequestedPublishingInterval and new_keepalive_count != params.RequestedMaxKeepAliveCount:
_logger.info("KeepAliveCount will be updated to %s " "for consistency with RevisedPublishInterval", new_keepalive_count)
if (
revised_interval != params.RequestedPublishingInterval
and new_keepalive_count != params.RequestedMaxKeepAliveCount
):
_logger.info(
"KeepAliveCount will be updated to %s " "for consistency with RevisedPublishInterval",
new_keepalive_count,
)
modified_params = ua.ModifySubscriptionParameters()
# copy the existing subscription parameters
copy_dataclass_attr(params, modified_params)
Expand Down Expand Up @@ -811,7 +859,9 @@ async def get_namespace_index(self, uri: str) -> int:
async def delete_nodes(self, nodes: Iterable[Node], recursive=False) -> Tuple[List[Node], List[ua.StatusCode]]:
return await delete_nodes(self.uaclient, nodes, recursive)

async def import_xml(self, path=None, xmlstring=None, strict_mode=True, auto_load_definitions: bool = True) -> List[ua.NodeId]:
async def import_xml(
self, path=None, xmlstring=None, strict_mode=True, auto_load_definitions: bool = True
) -> List[ua.NodeId]:
"""
Import nodes defined in xml
"""
Expand Down Expand Up @@ -850,7 +900,9 @@ async def load_type_definitions(self, nodes=None):
_logger.warning("Deprecated since spec 1.04, call load_data_type_definitions")
return await load_type_definitions(self, nodes)

async def load_data_type_definitions(self, node: Optional[Node] = None, overwrite_existing: bool = False) -> Dict[str, Type]:
async def load_data_type_definitions(
self, node: Optional[Node] = None, overwrite_existing: bool = False
) -> Dict[str, Type]:
"""
Load custom types (custom structures/extension objects) definition from server
Generate Python classes for custom structures/extension objects defined in server
Expand Down Expand Up @@ -891,7 +943,9 @@ async def unregister_nodes(self, nodes: Iterable[Node]) -> None:
node.nodeid = node.basenodeid
node.basenodeid = None

async def read_attributes(self, nodes: Iterable[Node], attr: ua.AttributeIds = ua.AttributeIds.Value) -> List[ua.DataValue]:
async def read_attributes(
self, nodes: Iterable[Node], attr: ua.AttributeIds = ua.AttributeIds.Value
) -> List[ua.DataValue]:
"""
Read the attributes of multiple nodes.
"""
Expand All @@ -905,7 +959,9 @@ async def read_values(self, nodes: Iterable[Node]) -> List[Any]:
res = await self.read_attributes(nodes, attr=ua.AttributeIds.Value)
return [r.Value.Value if r.Value else None for r in res]

async def write_values(self, nodes: Iterable[Node], values: Iterable[Any], raise_on_partial_error: bool = True) -> List[ua.StatusCode]:
async def write_values(
self, nodes: Iterable[Node], values: Iterable[Any], raise_on_partial_error: bool = True
) -> List[ua.StatusCode]:
"""
Write values to multiple nodes in one ua call
"""
Expand Down Expand Up @@ -938,7 +994,9 @@ async def browse_nodes(self, nodes: Iterable[Node]) -> List[Tuple[Node, ua.Brows
results = await self.uaclient.browse(parameters)
return list(zip(nodes, results))

async def translate_browsepaths(self, starting_node: ua.NodeId, relative_paths: Iterable[Union[ua.RelativePath, str]]) -> List[ua.BrowsePathResult]:
async def translate_browsepaths(
self, starting_node: ua.NodeId, relative_paths: Iterable[Union[ua.RelativePath, str]]
) -> List[ua.BrowsePathResult]:
bpaths = []
for p in relative_paths:
try:
Expand Down
15 changes: 12 additions & 3 deletions asyncua/client/ha/ha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ async def start(self) -> None:
self.is_running = True

async def stop(self):
to_stop: Sequence[Union[KeepAlive, HaManager, Reconciliator]] = chain(self._keepalive_task, self._manager_task, self._reconciliator_task)
to_stop: Sequence[Union[KeepAlive, HaManager, Reconciliator]] = chain(
self._keepalive_task, self._manager_task, self._reconciliator_task
)
stop = [p.stop() for p in to_stop]

await asyncio.gather(*stop)
Expand Down Expand Up @@ -231,7 +233,9 @@ async def subscribe_data_change(
for url in self.urls:
vs = self.ideal_map[url].get(sub_name)
if not vs:
_logger.warning("The subscription specified for the data_change: %s doesn't exist in ideal_map", sub_name)
_logger.warning(
"The subscription specified for the data_change: %s doesn't exist in ideal_map", sub_name
)
return
vs.subscribe_data_change(nodes, attr, queuesize)
await self.hook_on_subscribe(nodes=nodes, attr=attr, queuesize=queuesize, url=url)
Expand Down Expand Up @@ -513,7 +517,12 @@ async def reconnect_warm(self) -> None:
healthy, unhealthy = await self.ha_client.group_clients_by_health()

async def reco_resub(client: Client, force: bool):
if force or not client.uaclient.protocol or client.uaclient.protocol and client.uaclient.protocol.state == UASocketProtocol.CLOSED:
if (
force
or not client.uaclient.protocol
or client.uaclient.protocol
and client.uaclient.protocol.state == UASocketProtocol.CLOSED
):
_logger.info("Virtually reconnecting and resubscribing %s", client)
await self.ha_client.reconnect(client=client)

Expand Down
8 changes: 6 additions & 2 deletions asyncua/client/ha/reconciliator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ async def update_nodes(self, real_map: SubMap, ideal_map: SubMap, targets: Set[s
real_sub = self.name_to_subscription[url].get(sub_name)
# in case the previous create_subscription request failed
if not real_sub:
_logger.warning("Can't create nodes for %s since underlying " "subscription for %s doesn't exist", url, sub_name)
_logger.warning(
"Can't create nodes for %s since underlying " "subscription for %s doesn't exist", url, sub_name
)
continue
vs_real = real_map[url][sub_name]
vs_ideal = ideal_map[url][sub_name]
Expand Down Expand Up @@ -310,7 +312,9 @@ async def update_subscription_modes(self, real_map: SubMap, ideal_map: SubMap, t
real_sub = self.name_to_subscription[url].get(sub_name)
# in case the previous create_subscription request failed
if not real_sub:
_logger.warning("Can't change modes for %s since underlying subscription for %s doesn't exist", url, sub_name)
_logger.warning(
"Can't change modes for %s since underlying subscription for %s doesn't exist", url, sub_name
)
continue
vs_real = real_map[url][sub_name]
vs_ideal = ideal_map[url][sub_name]
Expand Down
Loading

0 comments on commit 4575051

Please sign in to comment.