diff --git a/bittensor_wallet/keyfile.py b/bittensor_wallet/keyfile.py index a53d429..e68ede7 100644 --- a/bittensor_wallet/keyfile.py +++ b/bittensor_wallet/keyfile.py @@ -303,54 +303,64 @@ def decrypt_keyfile_data( if coldkey_name is not None and password is None: password = get_coldkey_password_from_environment(coldkey_name) - try: - password = ( - getpass.getpass("Enter password to unlock key: ") - if password is None - else password - ) - with console.status(":key: Decrypting key..."): - # NaCl SecretBox decrypt. - if keyfile_data_is_encrypted_nacl(keyfile_data): - password = bytes(password, "utf-8") - kdf = pwhash.argon2i.kdf - key = kdf( - secret.SecretBox.KEY_SIZE, - password, - NACL_SALT, - opslimit=pwhash.argon2i.OPSLIMIT_SENSITIVE, - memlimit=pwhash.argon2i.MEMLIMIT_SENSITIVE, - ) - box = secret.SecretBox(key) - decrypted_keyfile_data = box.decrypt(keyfile_data[len("$NACL") :]) - # Ansible decrypt. - elif keyfile_data_is_encrypted_ansible(keyfile_data): - vault = Vault(password) - try: + finished_password_input = False + while not finished_password_input: + try: + password_input: str = ( + getpass.getpass("Enter password to unlock key: ") + if password is None + else password + ) + with console.status(":key: Decrypting key..."): + # NaCl SecretBox decrypt. + if keyfile_data_is_encrypted_nacl(keyfile_data): + password = bytes(password_input, "utf-8") + kdf = pwhash.argon2i.kdf + key = kdf( + secret.SecretBox.KEY_SIZE, + password, + NACL_SALT, + opslimit=pwhash.argon2i.OPSLIMIT_SENSITIVE, + memlimit=pwhash.argon2i.MEMLIMIT_SENSITIVE, + ) + box = secret.SecretBox(key) + decrypted_keyfile_data = box.decrypt(keyfile_data[len("$NACL") :]) + finished_password_input = True + + # Ansible decrypt. + elif keyfile_data_is_encrypted_ansible(keyfile_data): + vault = Vault(password) decrypted_keyfile_data = vault.load(keyfile_data) - except AnsibleVaultError: - raise KeyFileError("Invalid password") - # Legacy decrypt. - elif keyfile_data_is_encrypted_legacy(keyfile_data): - __SALT = ( - b"Iguesscyborgslikemyselfhaveatendencytobeparanoidaboutourorigins" - ) - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - salt=__SALT, - length=32, - iterations=10000000, - backend=default_backend(), - ) - key = base64.urlsafe_b64encode(kdf.derive(password.encode())) - cipher_suite = Fernet(key) - decrypted_keyfile_data = cipher_suite.decrypt(keyfile_data) - # Unknown. + finished_password_input = True + + # Legacy decrypt. + elif keyfile_data_is_encrypted_legacy(keyfile_data): + __SALT = b"Iguesscyborgslikemyselfhaveatendencytobeparanoidaboutourorigins" + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + salt=__SALT, + length=32, + iterations=10000000, + backend=default_backend(), + ) + key = base64.urlsafe_b64encode(kdf.derive(password.encode())) + cipher_suite = Fernet(key) + decrypted_keyfile_data = cipher_suite.decrypt(keyfile_data) + finished_password_input = True + # Unknown. + else: + raise KeyFileError(f"keyfile data: {str(keyfile_data)} is corrupt") + + except (InvalidSignature, InvalidKey, InvalidToken, AnsibleVaultError): + console.print("Wrong password, try again") + password = None + + except Exception as e: + if "Decryption failed. Ciphertext failed verification" in str(e): + console.print("Wrong password, try again") + password = None else: - raise KeyFileError(f"keyfile data: {keyfile_data.decode()} is corrupt.") - - except (InvalidSignature, InvalidKey, InvalidToken): - raise KeyFileError("Invalid password") + raise if not isinstance(decrypted_keyfile_data, bytes): decrypted_keyfile_data = json.dumps(decrypted_keyfile_data).encode()