diff --git a/wdapy/usbmux/usbmux.py b/wdapy/usbmux/usbmux.py index 14a4cb0..7013e33 100644 --- a/wdapy/usbmux/usbmux.py +++ b/wdapy/usbmux/usbmux.py @@ -22,6 +22,14 @@ class MuxConnectError(MuxError): logger = logging.getLogger("wdapy.usbmux") +def connect_with_timeout(family, addr, timeout: float): + sock = socket.socket(family, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect(addr) + sock.settimeout(socket.getdefaulttimeout()) + return sock + + class SafeStreamSocket(): def __init__(self, addr: Union[str, tuple, socket.socket]): """ @@ -39,8 +47,7 @@ def __init__(self, addr: Union[str, tuple, socket.socket]): else: family = socket.AF_INET - self._sock = socket.socket(family, socket.SOCK_STREAM) - self._sock.connect(addr) + self._sock = connect_with_timeout(family, addr, timeout=5) def recvall(self, size: int) -> bytearray: buf = bytearray()