diff --git a/Cargo.toml b/Cargo.toml index ff1acbf2b..ed8dd7aad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ sui-gateway = { version = "^1.0.0", path = "packages/sui-gateway" } sui-types = { version = "^1.0.0", path = "packages/sui-types" } starknet-checked-felt = { version = "^1.0.0", path = "packages/starknet-checked-felt" } starknet-types-core = { version = "0.1.7" } +starknet-types = { version = "^1.0.0", path = "packages/starknet-types" } starknet-core = "0.12.0" starknet-providers = "0.12.0" syn = "2.0.92" diff --git a/ampd/Cargo.toml b/ampd/Cargo.toml index f5ab7601e..202fa78de 100644 --- a/ampd/Cargo.toml +++ b/ampd/Cargo.toml @@ -24,7 +24,7 @@ error-stack = { workspace = true } ethers-contract = { workspace = true } ethers-core = { workspace = true } ethers-providers = { version = "2.0.13", default-features = false, features = [ - "rustls", + "rustls", ] } events = { workspace = true } events-derive = { workspace = true } @@ -39,7 +39,9 @@ move-core-types = { git = "https://github.com/mystenlabs/sui", tag = "testnet-v1 multisig = { workspace = true, features = ["library"] } multiversx-sdk = "0.6.1" num-traits = { workspace = true } -openssl = { version = "0.10.35", features = ["vendored"] } # Needed to make arm compilation work by forcing vendoring +openssl = { version = "0.10.35", features = [ + "vendored", +] } # Needed to make arm compilation work by forcing vendoring prost = "0.11.9" prost-types = "0.11.9" report = { workspace = true } @@ -58,7 +60,7 @@ sui-gateway = { workspace = true } sui-json-rpc-types = { git = "https://github.com/mystenlabs/sui", tag = "testnet-v1.39.1" } sui-types = { git = "https://github.com/mystenlabs/sui", tag = "testnet-v1.39.1" } tendermint = "0.35.0" -tendermint-rpc = { version = "0.35.0", features = [ "http-client" ] } +tendermint-rpc = { version = "0.35.0", features = ["http-client"] } thiserror = { workspace = true } tokio = { workspace = true, features = ["signal"] } tokio-stream = { workspace = true, features = ["sync"] } @@ -67,13 +69,22 @@ toml = "0.5.9" tonic = "0.9.2" tracing = { version = "0.1.37", features = ["valuable", "log"] } tracing-core = { version = "0.1.30", features = ["valuable"] } -tracing-subscriber = { version = "0.3.16", features = ["json", "valuable", "env-filter"] } +tracing-subscriber = { version = "0.3.16", features = [ + "json", + "valuable", + "env-filter", +] } typed-builder = "0.18.2" url = "2.3.1" valuable = { version = "0.1.0", features = ["derive"] } valuable-serde = { version = "0.1.0", features = ["std"] } voting-verifier = { workspace = true } +starknet-core = { workspace = true } +starknet-providers = { workspace = true } +starknet-types = { workspace = true } +starknet-checked-felt = { workspace = true } + [dev-dependencies] assert_ok = { workspace = true } ed25519-dalek = { workspace = true, features = ["rand_core"] } diff --git a/ampd/src/handlers/config.rs b/ampd/src/handlers/config.rs index c7bb9da9b..36018ac9a 100644 --- a/ampd/src/handlers/config.rs +++ b/ampd/src/handlers/config.rs @@ -63,6 +63,26 @@ pub enum Config { cosmwasm_contract: TMAddress, rpc_url: Url, }, + StarknetMsgVerifier { + cosmwasm_contract: TMAddress, + rpc_url: Url, + }, +} + +fn validate_starknet_msg_verifier_config<'de, D>(configs: &[Config]) -> Result<(), D::Error> +where + D: Deserializer<'de>, +{ + match configs + .iter() + .filter(|config| matches!(config, Config::StarknetMsgVerifier { .. })) + .count() + { + count if count > 1 => Err(de::Error::custom( + "only one Starknet msg verifier config is allowed", + )), + _ => Ok(()), + } } fn validate_evm_verifier_set_verifier_configs<'de, D>(configs: &[Config]) -> Result<(), D::Error> @@ -133,6 +153,7 @@ where { let configs: Vec = Deserialize::deserialize(deserializer)?; + validate_starknet_msg_verifier_config::(&configs)?; validate_evm_msg_verifier_configs::(&configs)?; validate_evm_verifier_set_verifier_configs::(&configs)?; diff --git a/ampd/src/handlers/mod.rs b/ampd/src/handlers/mod.rs index 1f8868164..ce724a7e5 100644 --- a/ampd/src/handlers/mod.rs +++ b/ampd/src/handlers/mod.rs @@ -5,6 +5,8 @@ pub mod evm_verify_verifier_set; pub mod multisig; pub mod mvx_verify_msg; pub mod mvx_verify_verifier_set; +pub mod starknet_verify_msg; +pub mod starknet_verify_verifier_set; pub(crate) mod stellar_verify_msg; pub(crate) mod stellar_verify_verifier_set; pub mod sui_verify_msg; diff --git a/ampd/src/handlers/starknet_verify_msg.rs b/ampd/src/handlers/starknet_verify_msg.rs new file mode 100644 index 000000000..2c451ee22 --- /dev/null +++ b/ampd/src/handlers/starknet_verify_msg.rs @@ -0,0 +1,513 @@ +use std::collections::HashMap; +use std::convert::TryInto; + +use async_trait::async_trait; +use axelar_wasm_std::msg_id::FieldElementAndEventIndex; +use axelar_wasm_std::voting::{PollId, Vote}; +use cosmrs::cosmwasm::MsgExecuteContract; +use cosmrs::tx::Msg; +use cosmrs::Any; +use error_stack::{FutureExt, ResultExt}; +use events::Error::EventTypeMismatch; +use events_derive::try_from; +use futures::future::try_join_all; +use itertools::Itertools; +use router_api::ChainName; +use serde::Deserialize; +use starknet_checked_felt::CheckedFelt; +use starknet_core::types::Felt; +use starknet_types::events::contract_call::ContractCallEvent; +use tokio::sync::watch::Receiver; +use tracing::info; +use voting_verifier::msg::ExecuteMsg; + +use crate::event_processor::EventHandler; +use crate::handlers::errors::Error; +use crate::handlers::errors::Error::DeserializeEvent; +use crate::starknet::json_rpc::StarknetClient; +use crate::starknet::verifier::verify_msg; +use crate::types::{Hash, TMAddress}; + +type Result = error_stack::Result; + +#[derive(Deserialize, Debug)] +pub struct Message { + pub message_id: FieldElementAndEventIndex, + pub destination_address: String, + pub destination_chain: ChainName, + pub source_address: CheckedFelt, + pub payload_hash: Hash, +} + +#[derive(Deserialize, Debug)] +#[try_from("wasm-messages_poll_started")] +struct PollStartedEvent { + #[serde(rename = "_contract_address")] + contract_address: TMAddress, + poll_id: PollId, + source_gateway_address: String, + expires_at: u64, + messages: Vec, + participants: Vec, +} + +pub struct Handler +where + C: StarknetClient, +{ + verifier: TMAddress, + voting_verifier: TMAddress, + rpc_client: C, + latest_block_height: Receiver, +} + +impl Handler +where + C: StarknetClient + Send + Sync, +{ + pub fn new( + verifier: TMAddress, + voting_verifier: TMAddress, + rpc_client: C, + latest_block_height: Receiver, + ) -> Self { + Self { + verifier, + voting_verifier, + rpc_client, + latest_block_height, + } + } + + fn vote_msg(&self, poll_id: PollId, votes: Vec) -> MsgExecuteContract { + MsgExecuteContract { + sender: self.verifier.as_ref().clone(), + contract: self.voting_verifier.as_ref().clone(), + msg: serde_json::to_vec(&ExecuteMsg::Vote { poll_id, votes }) + .expect("vote msg should serialize"), + funds: vec![], + } + } +} + +#[async_trait] +impl EventHandler for Handler +where + V: StarknetClient + Send + Sync, +{ + type Err = Error; + + async fn handle(&self, event: &events::Event) -> Result> { + let PollStartedEvent { + poll_id, + source_gateway_address, + messages, + participants, + expires_at, + contract_address, + .. + } = match event.try_into() as error_stack::Result<_, _> { + Err(report) if matches!(report.current_context(), EventTypeMismatch(_)) => { + return Ok(vec![]); + } + event => event.change_context(DeserializeEvent)?, + }; + + if self.voting_verifier != contract_address { + return Ok(vec![]); + } + + if !participants.contains(&self.verifier) { + return Ok(vec![]); + } + + let latest_block_height = *self.latest_block_height.borrow(); + if latest_block_height >= expires_at { + info!(poll_id = poll_id.to_string(), "skipping expired poll"); + return Ok(vec![]); + } + + let unique_msgs = messages + .iter() + .unique_by(|msg| &msg.message_id.tx_hash) + .collect::>(); + + // key is the tx_hash of the tx holding the event + let events: HashMap = + try_join_all(unique_msgs.iter().map(|msg| { + self.rpc_client + .get_event_by_hash_contract_call(msg.message_id.tx_hash.clone()) + })) + .change_context(Error::TxReceipts) + .await? + .into_iter() + .flatten() + .collect(); + + let mut votes = vec![]; + for msg in unique_msgs { + if !events.contains_key(&msg.message_id.tx_hash) { + votes.push(Vote::NotFound); + continue; + } + votes.push(verify_msg( + events.get(&msg.message_id.tx_hash).unwrap(), // safe to unwrap, because of previous check + msg, + &source_gateway_address, + )); + } + + Ok(vec![self + .vote_msg(poll_id, votes) + .into_any() + .expect("vote msg should serialize")]) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use base64::engine::general_purpose::STANDARD; + use base64::Engine; + use ethers_core::types::H256; + use events::Event; + use mockall::predicate::eq; + use starknet_core::types::Felt; + use tendermint::abci; + use tokio::sync::watch; + use tokio::test as async_test; + use voting_verifier::events::{PollMetadata, PollStarted, TxEventConfirmation}; + + use super::*; + use crate::starknet::json_rpc::MockStarknetClient; + use crate::PREFIX; + + #[async_test] + async fn should_correctly_validate_messages() { + // Setup the context + let voting_verifier = TMAddress::random(PREFIX); + let verifier = TMAddress::random(PREFIX); + let expiration = 100u64; + let (_, rx) = watch::channel(expiration - 1); + + // Prepare the rpc client, which fetches the event and the vote broadcaster + let mut rpc_client = MockStarknetClient::new(); + rpc_client + .expect_get_event_by_hash_contract_call() + .returning(|_| { + Ok(Some(( + Felt::from_str( + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e", + ) + .unwrap(), + ContractCallEvent { + from_contract_addr: String::from("source-gw-addr"), + destination_address: String::from("destination-address"), + destination_chain: "ethereum".parse().unwrap(), + source_address: Felt::ONE.into(), + payload_hash: H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, + 123, 86, 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, + 234, 200, + ]), + }, + ))) + }); + + let event: Event = get_event( + get_poll_started_event_with_two_msgs(participants(5, Some(verifier.clone())), 100_u64), + &voting_verifier, + ); + + let handler = super::Handler::new(verifier, voting_verifier, rpc_client, rx); + let result = handler.handle(&event).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(MsgExecuteContract::from_any(result.first().unwrap()).is_ok()); + } + + #[async_test] + async fn should_skip_duplicate_messages() { + // Setup the context + let voting_verifier = TMAddress::random(PREFIX); + let verifier = TMAddress::random(PREFIX); + let expiration = 100u64; + let (_, rx) = watch::channel(expiration - 1); + + // Prepare the rpc client, which fetches the event and the vote broadcaster + let mut rpc_client = MockStarknetClient::new(); + rpc_client + .expect_get_event_by_hash_contract_call() + .once() + .with(eq(CheckedFelt::from_str( + "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f", + ) + .unwrap())) + .returning(|_| { + Ok(Some(( + Felt::from_str( + "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f", + ) + .unwrap(), + ContractCallEvent { + from_contract_addr: String::from("source-gw-addr"), + destination_address: String::from("destination-address"), + destination_chain: "ethereum".parse().unwrap(), + source_address: Felt::ONE.into(), + payload_hash: H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, + 123, 86, 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, + 234, 200, + ]), + }, + ))) + }); + + let event: Event = get_event( + get_poll_started_event_with_duplicate_msgs( + participants(5, Some(verifier.clone())), + 100, + ), + &voting_verifier, + ); + + let handler = super::Handler::new(verifier, voting_verifier, rpc_client, rx); + let result = handler.handle(&event).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(MsgExecuteContract::from_any(result.first().unwrap()).is_ok()); + } + + #[async_test] + async fn should_skip_wrong_verifier_address() { + // Setup the context + let voting_verifier = TMAddress::random(PREFIX); + let verifier = TMAddress::random(PREFIX); + let expiration = 100u64; + let (_, rx) = watch::channel(expiration - 1); + + // Prepare the rpc client, which fetches the event and the vote broadcaster + let mut rpc_client = MockStarknetClient::new(); + rpc_client.expect_get_event_by_hash_contract_call().times(0); + + let event: Event = get_event( + get_poll_started_event_with_duplicate_msgs( + participants(5, Some(verifier.clone())), + 100, + ), + &TMAddress::random(PREFIX), // some other random address + ); + + let handler = super::Handler::new(verifier, voting_verifier, rpc_client, rx); + + let result = handler.handle(&event).await.unwrap(); + assert_eq!(result, vec![]); + } + + #[async_test] + async fn should_skip_non_participating_verifier() { + // Setup the context + let voting_verifier = TMAddress::random(PREFIX); + let verifier = TMAddress::random(PREFIX); + let expiration = 100u64; + let (_, rx) = watch::channel(expiration - 1); + + // Prepare the rpc client, which fetches the event and the vote broadcaster + let mut rpc_client = MockStarknetClient::new(); + rpc_client.expect_get_event_by_hash_contract_call().times(0); + + let event: Event = get_event( + // woker is not in participat set + get_poll_started_event_with_duplicate_msgs(participants(5, None), 100), + &voting_verifier, + ); + + let handler = super::Handler::new(verifier, voting_verifier, rpc_client, rx); + + let result = handler.handle(&event).await.unwrap(); + assert_eq!(result, vec![]); + } + + #[async_test] + async fn should_skip_expired_poll_event() { + // Setup the context + let voting_verifier = TMAddress::random(PREFIX); + let verifier = TMAddress::random(PREFIX); + let expiration = 100u64; + let (_, rx) = watch::channel(expiration); // expired! + + // Prepare the rpc client, which fetches the event and the vote broadcaster + let mut rpc_client = MockStarknetClient::new(); + rpc_client.expect_get_event_by_hash_contract_call().times(0); + + let event: Event = get_event( + get_poll_started_event_with_duplicate_msgs( + participants(5, Some(verifier.clone())), + 100, + ), + &voting_verifier, + ); + + let handler = super::Handler::new(verifier, voting_verifier, rpc_client, rx); + + let result = handler.handle(&event).await.unwrap(); + assert_eq!(result, vec![]); + } + + fn participants(n: u8, verifier: Option) -> Vec { + (0..n) + .map(|_| TMAddress::random(PREFIX)) + .chain(verifier) + .collect() + } + + fn get_event(event: impl Into, contract_address: &TMAddress) -> Event { + let mut event: cosmwasm_std::Event = event.into(); + + event.ty = format!("wasm-{}", event.ty); + event = event.add_attribute("_contract_address", contract_address.to_string()); + + abci::Event::new( + event.ty, + event + .attributes + .into_iter() + .map(|cosmwasm_std::Attribute { key, value }| { + (STANDARD.encode(key), STANDARD.encode(value)) + }), + ) + .try_into() + .unwrap() + } + + fn get_poll_started_event_with_two_msgs( + participants: Vec, + expires_at: u64, + ) -> PollStarted { + PollStarted::Messages { + metadata: PollMetadata { + poll_id: "100".parse().unwrap(), + source_chain: "starknet".parse().unwrap(), + source_gateway_address: "source-gw-addr".parse().unwrap(), + confirmation_height: 15, + expires_at, + participants: participants + .into_iter() + .map(|addr| cosmwasm_std::Addr::unchecked(addr.to_string())) + .collect(), + }, + messages: vec![ + TxEventConfirmation { + tx_id: "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e" + .parse() + .unwrap(), + message_id: + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e-0" + .parse() + .unwrap(), + event_index: 0, + source_address: + "0x0000000000000000000000000000000000000000000000000000000000000001" + .parse() + .unwrap(), + destination_chain: "ethereum".parse().unwrap(), + destination_address: "destination-address".parse().unwrap(), + payload_hash: H256::from_slice(&[ + // keccak256("hello") + 28, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, + 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]) + .into(), + }, + TxEventConfirmation { + tx_id: "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f" + .parse() + .unwrap(), + message_id: + "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f-1" + .parse() + .unwrap(), + event_index: 1, + source_address: + "0x0000000000000000000000000000000000000000000000000000000000000001" + .parse() + .unwrap(), + destination_chain: "ethereum".parse().unwrap(), + destination_address: "destination-address".parse().unwrap(), + payload_hash: H256::from_slice(&[ + // keccak256("hello") + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, + 86, 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]) + .into(), + }, + ], + } + } + + fn get_poll_started_event_with_duplicate_msgs( + participants: Vec, + expires_at: u64, + ) -> PollStarted { + PollStarted::Messages { + metadata: PollMetadata { + poll_id: "100".parse().unwrap(), + source_chain: "starknet".parse().unwrap(), + source_gateway_address: "source-gw-addr".parse().unwrap(), + confirmation_height: 15, + expires_at, + participants: participants + .into_iter() + .map(|addr| cosmwasm_std::Addr::unchecked(addr.to_string())) + .collect(), + }, + messages: vec![ + TxEventConfirmation { + tx_id: "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f" + .parse() + .unwrap(), + message_id: + "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f-1" + .parse() + .unwrap(), + event_index: 1, + source_address: + "0x0000000000000000000000000000000000000000000000000000000000000001" + .parse() + .unwrap(), + destination_chain: "ethereum".parse().unwrap(), + destination_address: "destination-address".parse().unwrap(), + payload_hash: H256::from_slice(&[ + // keccak256("hello") + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, + 86, 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]) + .into(), + }, + TxEventConfirmation { + tx_id: "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f" + .parse() + .unwrap(), + message_id: + "0x045410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439f-1" + .parse() + .unwrap(), + event_index: 1, + source_address: + "0x0000000000000000000000000000000000000000000000000000000000000001" + .parse() + .unwrap(), + destination_chain: "ethereum".parse().unwrap(), + destination_address: "destination-address".parse().unwrap(), + payload_hash: H256::from_slice(&[ + // keccak256("hello") + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, + 86, 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]) + .into(), + }, + ], + } + } +} diff --git a/ampd/src/handlers/starknet_verify_verifier_set.rs b/ampd/src/handlers/starknet_verify_verifier_set.rs new file mode 100644 index 000000000..265b8a767 --- /dev/null +++ b/ampd/src/handlers/starknet_verify_verifier_set.rs @@ -0,0 +1,158 @@ +//! Module responsible for handling verification of verifier set changes. +//! It processes events related to verifier set, verifies them against the Starknet chain, +//! and manages the voting process for confirming these changes. + +use std::convert::TryInto; + +use async_trait::async_trait; +use axelar_wasm_std::msg_id::FieldElementAndEventIndex; +use axelar_wasm_std::voting::{PollId, Vote}; +use cosmrs::cosmwasm::MsgExecuteContract; +use cosmrs::tx::Msg; +use cosmrs::Any; +use error_stack::ResultExt; +use events::Error::EventTypeMismatch; +use events::Event; +use events_derive::try_from; +use multisig::verifier_set::VerifierSet; +use serde::Deserialize; +use tokio::sync::watch::Receiver; +use tracing::{info, info_span}; +use valuable::Valuable; +use voting_verifier::msg::ExecuteMsg; + +use crate::event_processor::EventHandler; +use crate::handlers::errors::Error; +use crate::starknet::json_rpc::StarknetClient; +use crate::starknet::verifier::verify_verifier_set; +use crate::types::TMAddress; + +#[derive(Deserialize, Debug)] +pub struct VerifierSetConfirmation { + pub message_id: FieldElementAndEventIndex, + pub verifier_set: VerifierSet, +} + +#[derive(Deserialize, Debug)] +#[try_from("wasm-verifier_set_poll_started")] +struct PollStartedEvent { + poll_id: PollId, + source_gateway_address: String, + verifier_set: VerifierSetConfirmation, + participants: Vec, + expires_at: u64, +} + +pub struct Handler +where + C: StarknetClient + Send + Sync, +{ + verifier: TMAddress, + voting_verifier_contract: TMAddress, + rpc_client: C, + latest_block_height: Receiver, +} + +impl Handler +where + C: StarknetClient + Send + Sync, +{ + /// Handler for verifying verifier set updates from Starknet + /// + /// # Type Parameters + /// * `C` - A Starknet client type that implements the [`StarknetClient`] trait + #[allow(dead_code)] + pub fn new( + verifier: TMAddress, + voting_verifier_contract: TMAddress, + rpc_client: C, + latest_block_height: Receiver, + ) -> Self { + Self { + verifier, + voting_verifier_contract, + rpc_client, + latest_block_height, + } + } + + fn vote_msg(&self, poll_id: PollId, vote: Vote) -> MsgExecuteContract { + MsgExecuteContract { + sender: self.verifier.as_ref().clone(), + contract: self.voting_verifier_contract.as_ref().clone(), + msg: serde_json::to_vec(&ExecuteMsg::Vote { + poll_id, + votes: vec![vote], + }) + .expect("vote msg should serialize"), + funds: vec![], + } + } +} + +#[async_trait] +impl EventHandler for Handler +where + C: StarknetClient + Send + Sync + 'static, +{ + type Err = Error; + + async fn handle(&self, event: &Event) -> error_stack::Result, Self::Err> { + if !event.is_from_contract(self.voting_verifier_contract.as_ref()) { + return Ok(vec![]); + } + + let PollStartedEvent { + poll_id, + source_gateway_address, + verifier_set, + expires_at, + participants, + } = match event.try_into() as error_stack::Result<_, _> { + Err(report) if matches!(report.current_context(), EventTypeMismatch(_)) => { + return Ok(vec![]) + } + event => event.change_context(Error::DeserializeEvent)?, + }; + + if !participants.contains(&self.verifier) { + return Ok(vec![]); + } + + if *self.latest_block_height.borrow() >= expires_at { + info!(poll_id = poll_id.to_string(), "skipping expired poll"); + return Ok(vec![]); + } + + let transaction_response = self + .rpc_client + .get_event_by_hash_signers_rotated(verifier_set.message_id.tx_hash.clone()) + .await + .unwrap(); + + let vote = info_span!( + "verify a new verifier set", + poll_id = poll_id.to_string(), + message_id = verifier_set.message_id.to_string(), + ) + .in_scope(|| { + info!("ready to verify verifier set in poll",); + + let vote = transaction_response.map_or(Vote::NotFound, |tx_receipt| { + verify_verifier_set(&tx_receipt.1, &verifier_set, &source_gateway_address) + }); + + info!( + vote = vote.as_value(), + "ready to vote for a new verifier set in poll" + ); + + vote + }); + + Ok(vec![self + .vote_msg(poll_id, vote) + .into_any() + .expect("vote msg should serialize")]) + } +} diff --git a/ampd/src/lib.rs b/ampd/src/lib.rs index 02cb5971c..381f7b0a5 100644 --- a/ampd/src/lib.rs +++ b/ampd/src/lib.rs @@ -14,6 +14,7 @@ use evm::json_rpc::EthereumClient; use multiversx_sdk::gateway::GatewayProxy; use queue::queued_broadcaster::QueuedBroadcaster; use router_api::ChainName; +use starknet_providers::jsonrpc::HttpTransport; use thiserror::Error; use tofnd::grpc::{Multisig, MultisigClient}; use tokio::signal::unix::{signal, SignalKind}; @@ -40,6 +41,7 @@ mod health_check; mod json_rpc; mod mvx; mod queue; +pub(crate) mod starknet; mod stellar; mod sui; mod tm_client; @@ -389,6 +391,22 @@ where ), event_processor_config.clone(), ), + handlers::config::Config::StarknetMsgVerifier { + cosmwasm_contract, + rpc_url, + } => self.create_handler_task( + "starknet-msg-verifier", + handlers::starknet_verify_msg::Handler::new( + verifier.clone(), + cosmwasm_contract, + starknet::json_rpc::Client::new_with_transport(HttpTransport::new( + &rpc_url, + )) + .unwrap(), + self.block_height_monitor.latest_block_height(), + ), + event_processor_config.clone(), + ), }; self.event_processor = self.event_processor.add_task(task); } diff --git a/ampd/src/starknet/json_rpc.rs b/ampd/src/starknet/json_rpc.rs new file mode 100644 index 000000000..548b6a7c4 --- /dev/null +++ b/ampd/src/starknet/json_rpc.rs @@ -0,0 +1,1094 @@ +//! Verification implementation of Starknet JSON RPC client's verification of +//! transaction existence + +use async_trait::async_trait; +use error_stack::Report; +use mockall::automock; +use starknet_checked_felt::CheckedFelt; +use starknet_core::types::{ExecutionResult, Felt, FromStrError, TransactionReceipt}; +use starknet_providers::jsonrpc::JsonRpcTransport; +use starknet_providers::{JsonRpcClient, Provider, ProviderError}; +use starknet_types::events::contract_call::ContractCallEvent; +use starknet_types::events::signers_rotated::SignersRotatedEvent; +use thiserror::Error; + +type Result = error_stack::Result; + +#[derive(Debug, Error)] +pub enum StarknetClientError { + #[error(transparent)] + UrlParseError(#[from] url::ParseError), + #[error(transparent)] + JsonDeserializeError(#[from] serde_json::Error), + #[error("Failed to fetch tx receipt: {0}")] + FetchingReceipt(#[from] ProviderError), + #[error("Failed to create field element from string: {0}")] + FeltFromString(#[from] FromStrError), + #[error("Tx not successful")] + UnsuccessfulTx, +} + +/// Implementor of verification method(s) for given network using JSON RPC +/// client. +pub struct Client +where + T: JsonRpcTransport + Send + Sync + 'static, +{ + client: JsonRpcClient, +} + +impl Client +where + T: JsonRpcTransport + Send + Sync + 'static, +{ + /// Constructor. + /// Expects URL of any JSON RPC entry point of Starknet, which you can find + /// as constants in the `networks.rs` module + pub fn new_with_transport(transport: T) -> Result { + Ok(Client { + client: JsonRpcClient::new(transport), + }) + } +} + +/// A trait for fetching a ContractCall event, by a given tx_hash +/// and parsing parsing it into +/// `crate::starknet::events::contract_call::ContractCallEvent` +#[automock] +#[async_trait] +pub trait StarknetClient { + /// Attempts to fetch a ContractCall event, by a given `tx_hash`. + /// Returns a tuple `(tx_hash, event)` or a `StarknetClientError`. + async fn get_event_by_hash_contract_call( + &self, + tx_hash: CheckedFelt, + ) -> Result>; + + /// Attempts to fetch a SignersRotated event, by a given `tx_hash`. + /// Returns a tuple `(tx_hash, event)` or a `StarknetClientError`. + async fn get_event_by_hash_signers_rotated( + &self, + tx_hash: CheckedFelt, + ) -> Result>; +} + +#[async_trait] +impl StarknetClient for Client +where + T: JsonRpcTransport + Send + Sync + 'static, +{ + async fn get_event_by_hash_contract_call( + &self, + tx_hash: CheckedFelt, + ) -> Result> { + let receipt_with_block_info = self + .client + .get_transaction_receipt(tx_hash) + .await + .map_err(StarknetClientError::FetchingReceipt)?; + + if *receipt_with_block_info.receipt.execution_result() != ExecutionResult::Succeeded { + return Err(Report::new(StarknetClientError::UnsuccessfulTx)); + } + + let event: Option<(Felt, ContractCallEvent)> = match receipt_with_block_info.receipt { + TransactionReceipt::Invoke(tx) => { + // NOTE: There should be only one ContractCall event per gateway tx + tx.events + .iter() + .filter_map(|e| { + // NOTE: Here we ignore the error, because the event might + // not be ContractCall and that by itself is not erroneous behavior + if let Ok(cce) = ContractCallEvent::try_from(e.clone()) { + Some((tx.transaction_hash, cce)) + } else { + None + } + }) + .next() + } + TransactionReceipt::L1Handler(_) => None, + TransactionReceipt::Declare(_) => None, + TransactionReceipt::Deploy(_) => None, + TransactionReceipt::DeployAccount(_) => None, + }; + + Ok(event) + } + + /// Fetches a transaction receipt by hash and extracts a SignersRotatedEvent if present + /// + /// # Arguments + /// + /// * `tx_hash` - The hash of the transaction to fetch + /// + /// # Returns + /// + /// * `Ok(Some((tx_hash, SignersRotatedEvent)))` - If the transaction exists and contains a valid SignersRotatedEvent + /// * `Ok(None)` - If the transaction exists but contains no SignersRotatedEvent + /// * `Err(StarknetClientError)` - If there was an error fetching the receipt or the transaction failed + /// + /// # Errors + /// + /// Returns a `StarknetClientError` if: + /// * Failed to fetch the transaction receipt from the node + /// * The transaction execution was not successful + async fn get_event_by_hash_signers_rotated( + &self, + tx_hash: CheckedFelt, + ) -> Result> { + let receipt_with_block_info = self + .client + .get_transaction_receipt(tx_hash) + .await + .map_err(StarknetClientError::FetchingReceipt)?; + + if *receipt_with_block_info.receipt.execution_result() != ExecutionResult::Succeeded { + return Err(Report::new(StarknetClientError::UnsuccessfulTx)); + } + + let event: Option<(Felt, SignersRotatedEvent)> = match receipt_with_block_info.receipt { + TransactionReceipt::Invoke(tx) => tx + .events + .iter() + .filter_map(|e| { + if let Ok(sre) = SignersRotatedEvent::try_from(e.clone()) { + Some((tx.transaction_hash, sre)) + } else { + None + } + }) + .next(), + TransactionReceipt::L1Handler(_) => None, + TransactionReceipt::Declare(_) => None, + TransactionReceipt::Deploy(_) => None, + TransactionReceipt::DeployAccount(_) => None, + }; + + Ok(event) + } +} + +#[cfg(test)] +mod test { + + use std::str::FromStr; + + use axum::async_trait; + use ethers_core::types::H256; + use serde::de::DeserializeOwned; + use serde::Serialize; + use starknet_checked_felt::CheckedFelt; + use starknet_core::types::Felt; + use starknet_providers::jsonrpc::{ + HttpTransportError, JsonRpcMethod, JsonRpcResponse, JsonRpcTransport, + }; + use starknet_providers::{ProviderError, ProviderRequestData}; + use starknet_types::events::contract_call::ContractCallEvent; + use starknet_types::events::signers_rotated::SignersRotatedEvent; + + use super::{Client, StarknetClient, StarknetClientError}; + + #[tokio::test] + async fn invalid_signers_rotated_event_tx_fetch() { + let mock_client = + Client::new_with_transport(InvalidSignersRotatedEventMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_signers_rotated( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn deploy_account_tx_fetch() { + let mock_client = Client::new_with_transport(DeployAccountMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn deploy_tx_fetch() { + let mock_client = Client::new_with_transport(DeployMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn l1_handler_tx_fetch() { + let mock_client = Client::new_with_transport(L1HandlerMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn declare_tx_fetch() { + let mock_client = Client::new_with_transport(DeclareMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn invalid_contract_call_event_tx_fetch() { + let mock_client = + Client::new_with_transport(InvalidContractCallEventMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn no_events_tx_fetch() { + let mock_client = Client::new_with_transport(NoEventsMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.unwrap().is_none()); + } + + #[tokio::test] + async fn reverted_tx_fetch() { + let mock_client = Client::new_with_transport(RevertedMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event + .unwrap_err() + .contains::()); + } + + #[tokio::test] + async fn failing_tx_fetch() { + let mock_client = Client::new_with_transport(FailingMockTransport).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await; + + assert!(contract_call_event.is_err()); + } + + #[tokio::test] + async fn successful_signers_rotated_tx_fetch() { + let mock_client = Client::new_with_transport(ValidMockTransportSignersRotated).unwrap(); + let signers_rotated_event: (Felt, SignersRotatedEvent) = mock_client + .get_event_by_hash_signers_rotated( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await + .unwrap() // unwrap the result + .unwrap(); // unwrap the option + + assert_eq!( + signers_rotated_event.0, + Felt::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap() + ); + + let actual: SignersRotatedEvent = signers_rotated_event.1; + let expected: SignersRotatedEvent = SignersRotatedEvent { + from_address: "0x2".to_string(), + epoch: 1, + signers_hash: [ + 226, 62, 119, 4, 210, 79, 100, 110, 94, 54, 44, 97, 64, 122, 105, 210, 212, 32, 63, + 225, 67, 54, 50, 83, 200, 154, 39, 162, 106, 108, 184, 31, + ], + signers: starknet_types::events::signers_rotated::WeightedSigners { + signers: vec![starknet_types::events::signers_rotated::Signer { + signer: "0x3ec7d572a0fe479768ac46355651f22a982b99cc".to_string(), + weight: 1, + }], + threshold: 1, + nonce: [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 47, 228, 157, + ], + }, + }; + + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn successful_call_contract_tx_fetch() { + let mock_client = Client::new_with_transport(ValidMockTransportCallContract).unwrap(); + let contract_call_event = mock_client + .get_event_by_hash_contract_call( + CheckedFelt::try_from(&Felt::ONE.to_bytes_be()).unwrap(), + ) + .await + .unwrap() // unwrap the result + .unwrap(); // unwrap the option + + assert_eq!( + contract_call_event.0, + Felt::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap() + ); + assert_eq!( + contract_call_event.1, + ContractCallEvent { + from_contract_addr: + "0x0000000000000000000000000000000000000000000000000000000000000002".to_owned(), + destination_address: String::from("hello"), + destination_chain: String::from("destination_chain"), + source_address: Felt::from_str( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca" + ) + .unwrap(), + payload_hash: H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, + 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200 + ]) + } + ); + } + + struct FailingMockTransport; + + #[async_trait] + impl JsonRpcTransport for FailingMockTransport { + type Error = ProviderError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + Err(ProviderError::RateLimited) + } + } + + struct L1HandlerMockTransport; + + #[async_trait] + impl JsonRpcTransport for L1HandlerMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"L1_HANDLER\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"message_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct DeployAccountMockTransport; + + #[async_trait] + impl JsonRpcTransport for DeployAccountMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"DEPLOY_ACCOUNT\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"contract_address\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct DeployMockTransport; + + #[async_trait] + impl JsonRpcTransport for DeployMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"DEPLOY\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"contract_address\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct DeclareMockTransport; + + #[async_trait] + impl JsonRpcTransport for DeclareMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"DECLARE\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct NoEventsMockTransport; + + #[async_trait] + impl JsonRpcTransport for NoEventsMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct RevertedMockTransport; + + #[async_trait] + impl JsonRpcTransport for RevertedMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"REVERTED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct InvalidSignersRotatedEventMockTransport; + + #[async_trait] + impl JsonRpcTransport for InvalidSignersRotatedEventMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + // garbage "data" + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [ + { + \"from_address\": \"0x000000000000000000000000000000000000000000000000000000000000002\", + \"keys\": [ + \"0x01815547484542c49542242a23bc0a1b762af99232f38c0417050825aea8fc93\", + \"0x0268929df65ee595bb8592323f981351efdc467d564effc6d2e54d2e666e43ca\", + \"0x01\", + \"0xd4203fe143363253c89a27a26a6cb81f\", + \"0xe23e7704d24f646e5e362c61407a69d2\" + ], + \"data\": [ + \"0xb3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca\", + \"0x0000000000000000000000000000000000000000000000000000000000000000\", + \"0x00000000000000000000000000000000000000000000000000000068656c6c6f\", + \"0x0000000000000000000000000000000000000000000000000000000000000001\", + \"0x0000000000000000000000000000000056d9517b9c948127319a09a7a36deac8\", + \"0x000000000000000000000000000000001c8aff950685c2ed4bc3174f3472287b\", + \"0x0000000000000000000000000000000000000000000000000000000000000005\", + \"0x0000000000000000000000000000000000000000000000000000000000000068\", + \"0x0000000000000000000000000000000000000000000000000000000000000065\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006f\" + ] + } + ], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct InvalidContractCallEventMockTransport; + + #[async_trait] + impl JsonRpcTransport for InvalidContractCallEventMockTransport { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + // 1 byte for the pending_word, instead of 5 + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [ + { + \"from_address\": \"0x000000000000000000000000000000000000000000000000000000000000002\", + \"keys\": [ + \"0x01815547484542c49542242a23bc0a1b762af99232f38c0417050825aea8fc93\", + \"0x0268929df65ee595bb8592323f981351efdc467d564effc6d2e54d2e666e43ca\", + \"0x01\", + \"0xd4203fe143363253c89a27a26a6cb81f\", + \"0xe23e7704d24f646e5e362c61407a69d2\" + ], + \"data\": [ + \"0xb3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca\", + \"0x0000000000000000000000000000000000000000000000000000000000000000\", + \"0x00000000000000000000000000000000000000000000000000000068656c6c6f\", + \"0x0000000000000000000000000000000000000000000000000000000000000001\", + \"0x0000000000000000000000000000000056d9517b9c948127319a09a7a36deac8\", + \"0x000000000000000000000000000000001c8aff950685c2ed4bc3174f3472287b\", + \"0x0000000000000000000000000000000000000000000000000000000000000005\", + \"0x0000000000000000000000000000000000000000000000000000000000000068\", + \"0x0000000000000000000000000000000000000000000000000000000000000065\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006f\" + ] + } + ], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct ValidMockTransportSignersRotated; + + #[async_trait] + impl JsonRpcTransport for ValidMockTransportSignersRotated { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x0000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [ + { + \"from_address\": \"0x0000000000000000000000000000000000000000000000000000000000000002\", + \"keys\": [ + \"0x01815547484542c49542242a23bc0a1b762af99232f38c0417050825aea8fc93\", + \"0x0268929df65ee595bb8592323f981351efdc467d564effc6d2e54d2e666e43ca\", + \"0x01\", + \"0xd4203fe143363253c89a27a26a6cb81f\", + \"0xe23e7704d24f646e5e362c61407a69d2\" + ], + \"data\": [ + \"0x01\", + \"0x3ec7d572a0fe479768ac46355651f22a982b99cc\", + \"0x01\", + \"0x01\", + \"0x2fe49d\", + \"0x00\" + ] + } + ], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } + + struct ValidMockTransportCallContract; + + #[async_trait] + impl JsonRpcTransport for ValidMockTransportCallContract { + type Error = HttpTransportError; + + async fn send_requests( + &self, + _requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + unimplemented!(); + } + + async fn send_request( + &self, + _method: JsonRpcMethod, + _params: P, + ) -> Result, Self::Error> + where + P: Serialize + Send + Sync, + R: DeserializeOwned, + { + let response_mock = "{ + \"jsonrpc\": \"2.0\", + \"result\": { + \"type\": \"INVOKE\", + \"transaction_hash\": \"0x0000000000000000000000000000000000000000000000000000000000000001\", + \"actual_fee\": { + \"amount\": \"0x3062e4c46d4\", + \"unit\": \"WEI\" + }, + \"execution_status\": \"SUCCEEDED\", + \"finality_status\": \"ACCEPTED_ON_L2\", + \"block_hash\": \"0x5820e3a0aaceebdbda0b308fdf666eff64f263f6ed8ee74d6f78683b65a997b\", + \"block_number\": 637493, + \"messages_sent\": [], + \"events\": [ + { + \"from_address\": \"0x0000000000000000000000000000000000000000000000000000000000000002\", + \"keys\": [ + \"0x034d074b86d78f064ec0a29639fcfab989c7a3ea6343653633624b2df9ec08f6\", + \"0x00000000000000000000000000000064657374696e6174696f6e5f636861696e\" + ], + \"data\": [ + \"0xb3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca\", + \"0x0000000000000000000000000000000000000000000000000000000000000000\", + \"0x00000000000000000000000000000000000000000000000000000068656c6c6f\", + \"0x0000000000000000000000000000000000000000000000000000000000000005\", + \"0x0000000000000000000000000000000056d9517b9c948127319a09a7a36deac8\", + \"0x000000000000000000000000000000001c8aff950685c2ed4bc3174f3472287b\", + \"0x0000000000000000000000000000000000000000000000000000000000000005\", + \"0x0000000000000000000000000000000000000000000000000000000000000068\", + \"0x0000000000000000000000000000000000000000000000000000000000000065\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006c\", + \"0x000000000000000000000000000000000000000000000000000000000000006f\" + ] + } + ], + \"execution_resources\": { + \"data_availability\": { + \"l1_data_gas\": 0, + \"l1_gas\": 0 + }, + \"memory_holes\": 1176, + \"pedersen_builtin_applications\": 34, + \"range_check_builtin_applications\": 1279, + \"steps\": 17574 + } + }, + \"id\": 0 +}"; + let parsed_response = serde_json::from_str(response_mock).map_err(Self::Error::Json)?; + + Ok(parsed_response) + } + } +} diff --git a/ampd/src/starknet/mod.rs b/ampd/src/starknet/mod.rs new file mode 100644 index 000000000..f5a7d0a3f --- /dev/null +++ b/ampd/src/starknet/mod.rs @@ -0,0 +1,2 @@ +pub mod json_rpc; +pub mod verifier; diff --git a/ampd/src/starknet/verifier.rs b/ampd/src/starknet/verifier.rs new file mode 100644 index 000000000..e122ddca6 --- /dev/null +++ b/ampd/src/starknet/verifier.rs @@ -0,0 +1,348 @@ +use axelar_wasm_std::voting::Vote; +use cosmwasm_std::HexBinary; +use starknet_core::types::Felt; +use starknet_types::events::contract_call::ContractCallEvent; +use starknet_types::events::signers_rotated::SignersRotatedEvent; + +use crate::handlers::starknet_verify_msg::Message; +use crate::handlers::starknet_verify_verifier_set::VerifierSetConfirmation; + +/// Attempts to fetch the tx provided in `axl_msg.tx_id`. +/// If successful, extracts and parses the ContractCall event +/// and compares it to the message from the relayer (via PollStarted event). +/// Also checks if the source_gateway_address with which +/// the voting verifier has been instantiated is the same address from +/// which the ContractCall event is coming. +pub fn verify_msg( + starknet_event: &ContractCallEvent, + msg: &Message, + source_gateway_address: &str, +) -> Vote { + if *starknet_event == *msg && starknet_event.from_contract_addr == source_gateway_address { + Vote::SucceededOnChain + } else { + Vote::NotFound + } +} + +impl PartialEq for ContractCallEvent { + fn eq(&self, axl_msg: &Message) -> bool { + Felt::from(axl_msg.source_address.clone()) == self.source_address + && axl_msg.destination_chain == self.destination_chain + && axl_msg.destination_address == self.destination_address + && axl_msg.payload_hash == self.payload_hash + } +} + +pub fn verify_verifier_set( + event: &SignersRotatedEvent, + confirmation: &VerifierSetConfirmation, + source_gateway_address: &str, +) -> Vote { + if event.signers.nonce != [0_u8; 32] + && event == confirmation + && event.from_address == source_gateway_address + { + Vote::SucceededOnChain + } else { + Vote::NotFound + } +} + +impl PartialEq for SignersRotatedEvent { + fn eq(&self, confirmation: &VerifierSetConfirmation) -> bool { + let expected = &confirmation.verifier_set; + + // Convert and sort expected signers + let mut expected_signers = expected + .signers + .values() + .map(|signer| { + if let multisig::key::PublicKey::Ecdsa(pubkey) = &signer.pub_key { + (pubkey.clone(), signer.weight.u128()) + } else { + // Skip non-ECDSA keys + (HexBinary::from_hex("").unwrap(), 0) + } + }) + .collect::>(); + expected_signers.sort(); + + // Convert and sort actual signers from the event + let mut actual_signers = self + .signers + .signers + .iter() + .map(|signer| (HexBinary::from_hex(&signer.signer).unwrap(), signer.weight)) + .collect::>(); + actual_signers.sort(); + + // Compare signers, threshold, and created_at timestamp + actual_signers == expected_signers + && self.signers.threshold == expected.threshold.u128() + && self.epoch == expected.created_at + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use std::str::FromStr; + + use axelar_wasm_std::msg_id::FieldElementAndEventIndex; + use axelar_wasm_std::voting::Vote; + use cosmwasm_std::{Addr, HexBinary, Uint128}; + use ethers_core::types::H256; + use multisig::msg::Signer; + use multisig::verifier_set::VerifierSet; + use router_api::ChainName; + use starknet_checked_felt::CheckedFelt; + use starknet_core::types::Felt; + use starknet_types::events::contract_call::ContractCallEvent; + use starknet_types::events::signers_rotated::{ + Signer as StarknetSigner, SignersRotatedEvent, WeightedSigners, + }; + + use super::verify_msg; + use crate::handlers::starknet_verify_msg::Message; + use crate::handlers::starknet_verify_verifier_set::VerifierSetConfirmation; + use crate::starknet::verifier::verify_verifier_set; + + // "hello" as payload + // "hello" as destination address + // "some_contract_address" as source address + // "destination_chain" as destination_chain + fn mock_valid_event() -> ContractCallEvent { + ContractCallEvent { + from_contract_addr: String::from( + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e", + ), + destination_address: String::from("destination_address"), + destination_chain: String::from("ethereum"), + source_address: Felt::from_str( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca", + ) + .unwrap(), + payload_hash: H256::from_slice(&[ + 28, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, 217, + 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]), + } + } + + fn mock_valid_message() -> Message { + Message { + message_id: FieldElementAndEventIndex { + tx_hash: CheckedFelt::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000001", + ) + .unwrap(), + event_index: 0, + }, + destination_address: String::from("destination_address"), + destination_chain: ChainName::from_str("ethereum").unwrap(), + source_address: CheckedFelt::from_str( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca", + ) + .unwrap(), + payload_hash: H256::from_slice(&[ + 28, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, 217, + 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]), + } + } + + #[test] + fn shoud_fail_different_source_gw() { + assert_eq!( + verify_msg( + &mock_valid_event(), + &mock_valid_message(), + &String::from("different"), + ), + Vote::NotFound + ) + } + + #[test] + fn shoud_fail_different_event_fields() { + let msg = mock_valid_message(); + let source_gw_address = + String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"); + + let mut event = mock_valid_event(); + event.destination_address = String::from("different"); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut event = { mock_valid_event() }; + event.destination_chain = String::from("different"); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut event = { mock_valid_event() }; + event.source_address = Felt::THREE.into(); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut event = { mock_valid_event() }; + event.payload_hash = H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, 217, 81, + 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, + 1, // last byte is different + ]); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + } + + #[test] + fn shoud_fail_different_msg_fields() { + let event = mock_valid_event(); + let source_gw_address = + String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"); + + let mut msg = mock_valid_message(); + msg.destination_address = String::from("different"); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut msg = { mock_valid_message() }; + msg.destination_chain = ChainName::from_str("avalanche").unwrap(); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut msg = { mock_valid_message() }; + msg.source_address = CheckedFelt::try_from(&Felt::THREE.to_bytes_be()).unwrap(); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + + let mut msg = { mock_valid_message() }; + msg.payload_hash = H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, 217, 81, + 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, + 1, // last byte is different + ]); + assert_eq!(verify_msg(&event, &msg, &source_gw_address), Vote::NotFound); + } + + #[test] + fn shoud_verify_event() { + assert_eq!( + verify_msg( + &mock_valid_event(), + &mock_valid_message(), + &String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"), + ), + Vote::SucceededOnChain + ) + } + + /// Verifier set - signers rotated + + fn mock_valid_confirmation_signers_rotated() -> VerifierSetConfirmation { + VerifierSetConfirmation { + verifier_set: mock_valid_verifier_set_signers_rotated(), + message_id: FieldElementAndEventIndex { + tx_hash: CheckedFelt::try_from(&[0_u8; 32]).unwrap(), + event_index: 0, + }, + } + } + + fn mock_valid_verifier_set_signers_rotated() -> VerifierSet { + let signers = vec![Signer { + address: Addr::unchecked("axelarvaloper1x86a8prx97ekkqej2x636utrdu23y8wupp9gk5"), + weight: Uint128::from(10u128), + pub_key: multisig::key::PublicKey::Ecdsa( + HexBinary::from_hex( + "03d123ce370b163acd576be0e32e436bb7e63262769881d35fa3573943bf6c6f81", + ) + .unwrap(), + ), + }]; + + let mut btree_signers = BTreeMap::new(); + for signer in signers { + btree_signers.insert(signer.address.clone().to_string(), signer); + } + + VerifierSet { + signers: btree_signers, + threshold: Uint128::one(), + created_at: 1, + } + } + + fn mock_valid_event_signers_rotated() -> SignersRotatedEvent { + SignersRotatedEvent { + // should be the same as the source gw address + from_address: String::from( + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e", + ), + epoch: 1, + signers_hash: [8_u8; 32], + signers: WeightedSigners { + signers: vec![StarknetSigner { + signer: String::from( + "03d123ce370b163acd576be0e32e436bb7e63262769881d35fa3573943bf6c6f81", + ), + weight: Uint128::from(10u128).into(), + }], + threshold: Uint128::one().into(), + nonce: [7_u8; 32], + }, + } + } + + fn mock_second_valid_event_signers_rotated() -> SignersRotatedEvent { + SignersRotatedEvent { + from_address: String::from( + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e", + ), + epoch: 1, + signers_hash: [8_u8; 32], + signers: WeightedSigners { + signers: vec![StarknetSigner { + signer: String::from( + "028584592624e742ba154c02df4c0b06e4e8a957ba081083ea9fe5309492aa6c7b", + ), + weight: Uint128::from(10u128).into(), + }], + threshold: Uint128::one().into(), + nonce: [7_u8; 32], + }, + } + } + + #[test] + fn should_verify_verifier_set() { + let source_gw_address = + String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"); + let confirmation = mock_valid_confirmation_signers_rotated(); + let event = mock_valid_event_signers_rotated(); + + assert_eq!( + verify_verifier_set(&event, &confirmation, &source_gw_address), + Vote::SucceededOnChain + ); + } + + #[test] + fn should_not_verify_verifier_set_if_nonce_zero() { + let mut event = mock_valid_event_signers_rotated(); + event.signers.nonce = [0_u8; 32]; + let gateway_address = + String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"); + let confirmation = mock_valid_confirmation_signers_rotated(); + + assert_eq!( + verify_verifier_set(&event, &confirmation, &gateway_address), + Vote::NotFound + ); + } + #[test] + fn shoud_not_verify_verifier_set_if_signers_mismatch() { + let source_gw_address = + String::from("0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e"); + let event = mock_second_valid_event_signers_rotated(); + let confirmation = mock_valid_confirmation_signers_rotated(); + + assert_eq!( + verify_verifier_set(&event, &confirmation, &source_gw_address), + Vote::NotFound + ); + } +} diff --git a/packages/axelar-wasm-std/Cargo.toml b/packages/axelar-wasm-std/Cargo.toml index f8178d449..4a5e5f605 100644 --- a/packages/axelar-wasm-std/Cargo.toml +++ b/packages/axelar-wasm-std/Cargo.toml @@ -49,17 +49,19 @@ serde_with = { version = "3.11.0", features = ["macros"] } sha3 = { workspace = true } starknet-checked-felt = { workspace = true } stellar-xdr = { workspace = true } +starknet-types-core = { workspace = true } strum = { workspace = true } sui-types = { workspace = true } thiserror = { workspace = true } valuable = { version = "0.1.0", features = ["derive"] } +crypto-bigint = { version = "0.5.5", features = ["rand_core"] } [dev-dependencies] assert_ok = { workspace = true } cw-multi-test = { workspace = true } goldie = { workspace = true } -hex = { version = "0.4.3", default-features = false } rand = { workspace = true } +hex = { version = "0.4.3", default-features = false } [lints] workspace = true diff --git a/packages/axelar-wasm-std/src/utils.rs b/packages/axelar-wasm-std/src/utils.rs index e45c0695a..292deb920 100644 --- a/packages/axelar-wasm-std/src/utils.rs +++ b/packages/axelar-wasm-std/src/utils.rs @@ -1,3 +1,6 @@ +use crypto_bigint::U256; +use starknet_types_core::felt::Felt; + pub trait TryMapExt { type Monad; fn try_map(self, func: impl FnMut(T) -> Result) -> Result, E>; @@ -19,6 +22,22 @@ impl TryMapExt for Vec { } } +/// since the `Felt` type doesn't error on overflow, we have to implement that check +pub fn does_felt_overflow_from_slice(felt_hex_slice: &[u8]) -> bool { + if felt_hex_slice.len() > 32 { + return true; + } + let felt_max_hex_str = format!("{:064x}", Felt::MAX); + U256::from_be_slice(felt_hex_slice) > U256::from_be_hex(&felt_max_hex_str) +} + +/// since the `Felt` type doesn't error on overflow, we have to implement that check +pub fn does_felt_overflow_from_str(felt_hex_str: &str) -> bool { + let felt_hex_str = felt_hex_str.trim_start_matches("0x"); + let felt_max_hex_str = format!("{:064x}", Felt::MAX); + U256::from_be_hex(felt_hex_str) > U256::from_be_hex(&felt_max_hex_str) +} + #[cfg(test)] mod test { use super::*; diff --git a/packages/starknet-types/Cargo.toml b/packages/starknet-types/Cargo.toml new file mode 100644 index 000000000..f6bcf72f0 --- /dev/null +++ b/packages/starknet-types/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "starknet-types" +version = "1.0.0" +rust-version.workspace = true +edition.workspace = true + +[dependencies] +axelar-wasm-std = { workspace = true, features = ["derive"] } +cosmwasm-std = { workspace = true } +router-api = { workspace = true } +ethers-core = { workspace = true } +starknet-checked-felt = { workspace = true } +starknet-core = { workspace = true } +starknet-types-core = { workspace = true } +error-stack = { workspace = true } +thiserror = { workspace = true } +itertools = { workspace = true } +hex = { workspace = true } +tokio = { version = "1", features = [ + "rt", + "signal", + "rt-multi-thread", + "macros", +] } +rand = { workspace = true } +# futures = { workspace = true } + +[lints] +workspace = true diff --git a/packages/starknet-types/src/error.rs b/packages/starknet-types/src/error.rs new file mode 100644 index 000000000..a8e5bc2b8 --- /dev/null +++ b/packages/starknet-types/src/error.rs @@ -0,0 +1,7 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("invalid starknet address")] + InvalidAddress, +} diff --git a/packages/starknet-types/src/events.rs b/packages/starknet-types/src/events.rs new file mode 100644 index 000000000..dfe883063 --- /dev/null +++ b/packages/starknet-types/src/events.rs @@ -0,0 +1,2 @@ +pub mod contract_call; +pub mod signers_rotated; diff --git a/packages/starknet-types/src/events/contract_call.rs b/packages/starknet-types/src/events/contract_call.rs new file mode 100644 index 000000000..d87c45f8a --- /dev/null +++ b/packages/starknet-types/src/events/contract_call.rs @@ -0,0 +1,229 @@ +use ethers_core::types::H256; +use starknet_core::types::Felt; +use starknet_core::utils::{parse_cairo_short_string, ParseCairoShortStringError}; +use thiserror::Error; + +use crate::types::byte_array::{ByteArray, ByteArrayError}; + +/// This is the event emitted by the gateway cairo contract on Starknet, +/// when the call_contract method is called from a third party. +#[derive(Debug, PartialEq, Clone)] +pub struct ContractCallEvent { + pub from_contract_addr: String, + pub destination_address: String, + pub destination_chain: String, + pub source_address: Felt, + pub payload_hash: H256, +} + +/// An error, representing failure to convert/parse a starknet event +/// to some specific event. +#[derive(Error, Debug)] +pub enum ContractCallError { + #[error("Invalid ContractCall event: {0}")] + InvalidEvent(String), + #[error("Cairo short string parse error: {0}")] + Cairo(#[from] ParseCairoShortStringError), + #[error("Failed felt conversion: {0}")] + TryFromConversion(String), + #[error("Event data/keys array index is out of bounds")] + OutOfBound, + #[error("ByteArray type error: {0}")] + ByteArray(#[from] ByteArrayError), + #[error("missing payload data for transaction")] + MissingPayloadData, + #[error("missing keys for transaction")] + MissingKeys, +} + +impl TryFrom for ContractCallEvent { + type Error = ContractCallError; + + fn try_from(event: starknet_core::types::Event) -> Result { + if event.data.is_empty() { + return Err(ContractCallError::MissingPayloadData); + } + if event.keys.is_empty() { + return Err(ContractCallError::MissingKeys); + } + + // `event.from_address` is the contract address, which emitted the event + let from_contract_addr = format!("0x{}", hex::encode(event.from_address.to_bytes_be())); + + // destination_chain is the second key in the event keys list (the first key + // defined from the event) + // + // This field, should not exceed 252 bits (a felt's length) + let destination_chain = parse_cairo_short_string(&event.keys[1])?; + + // source_address represents the original caller of the `call_contract` gateway + // method. It is the first field in data, by the order defined in the + // event. + let source_address = event.data[0]; + + // destination_contract_address (ByteArray) is composed of FieldElements + // from the second element to elemet X. + let destination_address_chunks_count_felt = event.data[1]; + let da_chunks_count: usize = u8::try_from(destination_address_chunks_count_felt) + .map_err(|err| ContractCallError::TryFromConversion(err.to_string()))? + .into(); + + // It's + 3, because we need to offset the 0th element, pending_word and + // pending_word_count, in addition to all chunks (da_chunks_count_usize) + let da_elements_start_index: usize = 1; + let da_elements_end_index: usize = da_chunks_count.wrapping_add(3); + let destination_address_byte_array: ByteArray = ByteArray::try_from( + event + .data + .get(da_elements_start_index..=da_elements_end_index) + .ok_or(ContractCallError::OutOfBound)? + .to_vec(), + )?; + let destination_address = destination_address_byte_array.try_to_string()?; + + // payload_hash is a keccak256, which is a combination of two felts (chunks) + // - first felt contains the 128 least significat bits (LSB) + // - second felt contains the 128 most significat bits (MSG) + let ph_chunk1_index: usize = da_elements_end_index.wrapping_add(1); + let ph_chunk2_index: usize = ph_chunk1_index.wrapping_add(1); + let mut payload_hash = [0; 32]; + let lsb: [u8; 32] = event + .data + .get(ph_chunk1_index) + .ok_or(ContractCallError::InvalidEvent( + "payload_hash chunk 1 out of range".to_owned(), + ))? + .to_bytes_be(); + let msb: [u8; 32] = event + .data + .get(ph_chunk2_index) + .ok_or(ContractCallError::InvalidEvent( + "payload_hash chunk 2 out of range".to_owned(), + ))? + .to_bytes_be(); + + // most significat bits, go before least significant bits for u256 construction + // check - https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_in_u256_values + payload_hash[..16].copy_from_slice(&msb[16..]); + payload_hash[16..].copy_from_slice(&lsb[16..]); + + Ok(ContractCallEvent { + from_contract_addr, + destination_address, + destination_chain, + source_address, + payload_hash: H256::from_slice(&payload_hash), + }) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use ethers_core::types::H256; + use starknet_core::types::{Felt, FromStrError}; + use starknet_core::utils::starknet_keccak; + + use super::ContractCallEvent; + use crate::events::contract_call::ContractCallError; + use crate::types::byte_array::ByteArrayError; + + #[test] + fn destination_address_chunks_offset_out_of_range() { + let mut starknet_event = get_dummy_event(); + // longer chunk, which offsets the destination_address byte array out of range + starknet_event.data[1] = + Felt::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(); + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!( + event, + ContractCallError::ByteArray(ByteArrayError::ParsingFelt(_)) + )); + } + + #[test] + fn destination_address_chunks_count_too_long() { + let mut starknet_event = get_dummy_event(); + // too long for u32 + starknet_event.data[1] = Felt::MAX; + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::TryFromConversion(_))); + } + + #[test] + fn invalid_dest_chain() { + let mut starknet_event = get_dummy_event(); + // too long for Cairo long string too long + starknet_event.keys[1] = Felt::MAX; + + let event = ContractCallEvent::try_from(starknet_event).unwrap_err(); + assert!(matches!(event, ContractCallError::Cairo(_))); + } + + #[test] + fn valid_call_contract_event() { + // the payload is the word "hello" + let starknet_event = get_dummy_event(); + let event = ContractCallEvent::try_from(starknet_event).unwrap(); + + assert_eq!( + event, + ContractCallEvent { + from_contract_addr: String::from( + "0x035410be6f4bf3f67f7c1bb4a93119d9d410b2f981bfafbf5dbbf5d37ae7439e" + ), + destination_address: String::from("hello"), + destination_chain: String::from("destination_chain"), + source_address: Felt::from_str( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca" + ) + .unwrap(), + payload_hash: H256::from_slice(&[ + 28, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, + 217, 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200 + ]) + } + ); + } + + fn get_dummy_event() -> starknet_core::types::Event { + // "hello" as payload + // "hello" as destination address + // "some_contract_address" as source address + // "destination_chain" as destination_chain + let event_data: Result, FromStrError> = vec![ + "0xb3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca", // the caller addr + "0x0000000000000000000000000000000000000000000000000000000000000000", // 0 data + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", // "hello" + "0x0000000000000000000000000000000000000000000000000000000000000005", // 5 bytes + "0x0000000000000000000000000000000056d9517b9c948127319a09a7a36deac8", // keccak256(hello) + "0x000000000000000000000000000000001c8aff950685c2ed4bc3174f3472287b", + "0x0000000000000000000000000000000000000000000000000000000000000005", // 5 bytes + "0x0000000000000000000000000000000000000000000000000000000000000068", // h + "0x0000000000000000000000000000000000000000000000000000000000000065", // e + "0x000000000000000000000000000000000000000000000000000000000000006c", // l + "0x000000000000000000000000000000000000000000000000000000000000006c", // l + "0x000000000000000000000000000000000000000000000000000000000000006f", // o + ] + .into_iter() + .map(Felt::from_str) + .collect(); + starknet_core::types::Event { + // I think it's a pedersen hash in actuallity, but for the tests I think it's ok + from_address: starknet_keccak("some_contract_address".as_bytes()), + keys: vec![ + starknet_keccak("ContractCall".as_bytes()), + // destination chain + Felt::from_str( + "0x00000000000000000000000000000064657374696e6174696f6e5f636861696e", + ) + .unwrap(), + ], + data: event_data.unwrap(), + } + } +} diff --git a/packages/starknet-types/src/events/signers_rotated.rs b/packages/starknet-types/src/events/signers_rotated.rs new file mode 100644 index 000000000..f022c14c3 --- /dev/null +++ b/packages/starknet-types/src/events/signers_rotated.rs @@ -0,0 +1,371 @@ +use starknet_core::types::{Event, Felt}; +use thiserror::Error; + +/// An error, representing failure to convert/parse a starknet event +/// to a SignersRotated event. +#[derive(Error, Debug)] +pub enum SignersRotatedErrors { + /// Error returned when a required signers hash is missing from a + /// transaction. + #[error("missing signers hash for transaction")] + MissingSignersHash, + + /// Error returned when payload data cannot be parsed correctly. + #[error("failed to parse payload data, error: {0}")] + FailedToParsePayloadData(String), + + /// Error returned when the payload data is missing. + #[error("missing payload data for transaction")] + MissingPayloadData, + + /// Error returned when the epoch number in a transaction is invalid or + /// unexpected. + #[error("incorrect epoch for transaction")] + IncorrectEpoch, + + /// Error returned when the first key doesn't correspod to the + /// SignersRotated event. + #[error("not a SignersRotated event")] + InvalidEvent, + + /// Error returned when the threshold in a transaction is invalid or + /// unexpected. + #[error("incorrect threshold for transaction")] + IncorrectThreshold, + + /// Error returned when the nonce in a transaction is missing. + #[error("missing nonce for transaction")] + MissingNonce, + + /// Error returned when the keys in a transaction are missing. + #[error("missing keys for transaction")] + MissingKeys, +} + +/// Represents a weighted signer +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Signer { + /// The address of the signer + pub signer: String, + /// The weight (voting power) of this signer + pub weight: u128, +} + +/// Represents a set of signers +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WeightedSigners { + pub signers: Vec, + pub threshold: u128, + pub nonce: [u8; 32], +} + +/// Represents a Starknet SignersRotated event +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SignersRotatedEvent { + /// The address of the sender + pub from_address: String, + /// The epoch number when this rotation occurred + pub epoch: u64, + /// The hash of the new signers + pub signers_hash: [u8; 32], + /// The new set of weighted signers with their voting power + pub signers: WeightedSigners, +} + +impl TryFrom for SignersRotatedEvent { + type Error = SignersRotatedErrors; + + /// Attempts to convert a Starknet event to a SignersRotated event + /// + /// # Arguments + /// + /// * `event` - The Starknet event to convert + /// + /// # Returns + /// + /// * `Ok(SignersRotated)` - Successfully converted event containing: + /// * `epoch` - The epoch number when rotation occurred + /// * `signers_hash` - Hash of the new signers (32 bytes) + /// * `signers` - New set of weighted signers with: + /// * List of signers with their addresses and weights + /// * Threshold for required voting power + /// * Nonce value (32 bytes) + /// + /// # Errors + /// + /// Returns a `SignersRotatedErrors` if: + /// * Event data or keys are empty + /// * Failed to parse epoch number + /// * Missing or invalid signers hash + /// * Failed to parse signers array length + /// * Failed to parse signer addresses or weights + /// * Missing or invalid threshold + /// * Missing or invalid nonce + fn try_from(event: Event) -> Result { + if event.data.is_empty() { + return Err(SignersRotatedErrors::MissingPayloadData); + } + if event.keys.is_empty() { + return Err(SignersRotatedErrors::MissingKeys); + } + + let from_address = event.from_address.to_hex_string(); + + // it starts at 2 because 0 is the selector and 1 is the from_address + let epoch_index = 2; + // INFO: there might be better way to convert to u64 + let epoch = event + .keys + .get(epoch_index) + .ok_or(SignersRotatedErrors::IncorrectEpoch)? + .to_string() + .parse::() + .map_err(|_| SignersRotatedErrors::IncorrectEpoch)?; + + // Construct signers hash + let mut signers_hash = [0_u8; 32]; + let lsb = event + .keys + .get(epoch_index + 1) + .map(Felt::to_bytes_be) + .ok_or(SignersRotatedErrors::MissingSignersHash)?; + let msb = event + .keys + .get(epoch_index + 2) + .map(Felt::to_bytes_be) + .ok_or(SignersRotatedErrors::MissingSignersHash)?; + signers_hash[..16].copy_from_slice(&msb[16..]); + signers_hash[16..].copy_from_slice(&lsb[16..]); + + // Parse signers array from event data + let mut buff_signers = vec![]; + + let signers_index = 0; + let signers_len = event.data[signers_index] + .to_string() + .parse::() + .map_err(|_| { + SignersRotatedErrors::FailedToParsePayloadData( + "failed to parse signers length".to_string(), + ) + })?; + let signers_end_index = signers_index.saturating_add(signers_len.saturating_mul(2)); + + // Parse signers and weights + for i in 0..signers_len { + let signer_index = signers_index + .saturating_add(1) + .saturating_add(i.saturating_mul(2)); + let weight_index = signer_index.saturating_add(1); + + // Get signer address as bytes + let signer = event.data[signer_index].to_hex_string(); + + // Parse weight + let weight = event.data[weight_index] + .to_string() + .parse::() + .map_err(|_| { + SignersRotatedErrors::FailedToParsePayloadData( + "failed to parse signer weight".to_string(), + ) + })?; + + buff_signers.push(Signer { signer, weight }); + } + + // Parse threshold + let threshold = event + .data + .get(signers_end_index) + .ok_or(SignersRotatedErrors::IncorrectThreshold)? + .to_string() + .parse::() + .map_err(|_| SignersRotatedErrors::IncorrectThreshold)?; + + // Parse nonce + let mut nonce = [0_u8; 32]; + let lsb = event + .data + .get(event.data.len().saturating_sub(2)) + .map(Felt::to_bytes_be) + .ok_or(SignersRotatedErrors::MissingNonce)?; + let msb = event + .data + .get(event.data.len().saturating_sub(1)) + .map(Felt::to_bytes_be) + .ok_or(SignersRotatedErrors::MissingNonce)?; + nonce[16..].copy_from_slice(&lsb[16..]); + nonce[..16].copy_from_slice(&msb[16..]); + + Ok(SignersRotatedEvent { + from_address, + epoch, + signers_hash, + signers: WeightedSigners { + signers: buff_signers, + threshold, + nonce, + }, + }) + } +} + +#[cfg(test)] +mod tests { + // use futures::stream::{FuturesUnordered, StreamExt}; + use starknet_core::types::{EmittedEvent, Felt}; + + use super::*; + + async fn get_valid_event() -> (Vec, Vec, Felt, Felt) { + let keys_data: Vec = vec![ + Felt::from_hex_unchecked( + "0x01815547484542c49542242a23bc0a1b762af99232f38c0417050825aea8fc93", + ), + Felt::from_hex_unchecked( + "0x0268929df65ee595bb8592323f981351efdc467d564effc6d2e54d2e666e43ca", + ), + Felt::from_hex_unchecked("0x01"), + Felt::from_hex_unchecked("0xd4203fe143363253c89a27a26a6cb81f"), + Felt::from_hex_unchecked("0xe23e7704d24f646e5e362c61407a69d2"), + ]; + + let event_data: Vec = vec![ + Felt::from_hex_unchecked("0x01"), + Felt::from_hex_unchecked("0x3ec7d572a0fe479768ac46355651f22a982b99cc"), + Felt::from_hex_unchecked("0x01"), + Felt::from_hex_unchecked("0x01"), + Felt::from_hex_unchecked("0x2fe49d"), + Felt::from_hex_unchecked("0x00"), + ]; + ( + keys_data, + event_data, + // sender_address + Felt::from_hex_unchecked( + "0x0282b4492e08d8b6bbec8dfe7412e42e897eef9c080c5b97be1537433e583bdc", + ), + // tx_hash + Felt::from_hex_unchecked( + "0x04663231715b17dd58cd08e63d6b31d2c86b158d4730da9a1b75ca2452c9910c", + ), + ) + } + + /// Generate a set of data with random modifications + async fn get_malformed_event() -> (Vec, Vec, Felt, Felt) { + let (mut keys_data, mut event_data, sender_address, tx_hash) = get_valid_event().await; + // Randomly remove an element from either vector + match rand::random::() { + true if !keys_data.is_empty() => { + let random_index = rand::random::() % keys_data.len(); + keys_data.remove(random_index); + } + false if !event_data.is_empty() => { + let random_index = rand::random::() % event_data.len(); + event_data.remove(random_index); + } + _ => {} + } + + // Randomly corrupt data values + if rand::random::() { + if let Some(elem) = keys_data.first_mut() { + *elem = Felt::from_hex_unchecked("0xdeadbeef"); + } + } + if rand::random::() { + if let Some(elem) = event_data.first_mut() { + *elem = Felt::from_hex_unchecked("0xcafebabe"); + } + } + + (keys_data, event_data, sender_address, tx_hash) + } + + #[tokio::test] + async fn test_try_from_event_happy_scenario() { + let (keys_data, event_data, sender_address, _tx_hash) = get_valid_event().await; + + assert!(SignersRotatedEvent::try_from(Event { + from_address: sender_address, + keys: keys_data, + data: event_data, + }) + .is_ok()); + } + + #[tokio::test] + async fn test_try_from_empty_event() { + let (_, _, sender_address, _tx_hash) = get_valid_event().await; + let result = SignersRotatedEvent::try_from(Event { + data: vec![], + from_address: sender_address, + keys: vec![], + }); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_try_from_event_missing_data() { + let (keys_data, _, sender_address, _tx_hash) = get_valid_event().await; + let event = SignersRotatedEvent::try_from(Event { + data: vec![], + from_address: sender_address, + keys: keys_data, + }); + + assert!(event.is_err()); + assert!(matches!( + event, + Err(SignersRotatedErrors::MissingPayloadData) + )); + } + + #[tokio::test] + async fn test_try_from_event_missing_keys() { + let (_, event_data, sender_address, _tx_hash) = get_valid_event().await; + let event = SignersRotatedEvent::try_from(Event { + data: event_data, + from_address: sender_address, + keys: vec![], + }); + + assert!(event.is_err()); + assert!(matches!(event, Err(SignersRotatedErrors::MissingKeys))); + } + + #[tokio::test] + async fn test_try_from_event_randomly_malformed_data_x1000() { + // let mut futures = FuturesUnordered::new(); + + for _ in 0..1000 { + // futures.push(async { + let (_, event_data, sender_address, tx_hash) = get_malformed_event().await; + let event = EmittedEvent { + data: event_data, + from_address: sender_address, + keys: vec![], + transaction_hash: tx_hash, + block_hash: None, + block_number: None, + }; + let result = SignersRotatedEvent::try_from(Event { + data: event.data, + from_address: event.from_address, + keys: event.keys, + }); + assert!(result.is_err()); + // }); + } + + // if any conversion succeeded then it should have failed + // while let Some(result) = futures.next().await { + // if !result { + // panic!("expected conversion to fail for malformed event"); + // } + // } + } +} diff --git a/packages/starknet-types/src/lib.rs b/packages/starknet-types/src/lib.rs new file mode 100644 index 000000000..7f2db3c29 --- /dev/null +++ b/packages/starknet-types/src/lib.rs @@ -0,0 +1,3 @@ +pub mod error; +pub mod events; +pub mod types; diff --git a/packages/starknet-types/src/types.rs b/packages/starknet-types/src/types.rs new file mode 100644 index 000000000..90b8423bf --- /dev/null +++ b/packages/starknet-types/src/types.rs @@ -0,0 +1,3 @@ +pub mod array_span; +pub mod byte_array; +pub mod starknet_message; diff --git a/packages/starknet-types/src/types/array_span.rs b/packages/starknet-types/src/types/array_span.rs new file mode 100644 index 000000000..eaf1fc613 --- /dev/null +++ b/packages/starknet-types/src/types/array_span.rs @@ -0,0 +1,187 @@ +use starknet_core::types::Felt; +use thiserror::Error; + +/// Represents Cairo's Array and Span types. +/// Implements `TryFrom>`, which is the way to create it. +/// +/// ## Example usage with the string "hello" +/// +/// ```rust +/// use starknet_types::types::array_span::ArraySpan; +/// use std::str::FromStr; +/// use starknet_core::types::Felt; +/// use starknet_core::types::FromStrError; +/// +/// let data: Result, FromStrError> = vec![ +/// "0x0000000000000000000000000000000000000000000000000000000000000005", +/// "0x0000000000000000000000000000000000000000000000000000000000000068", +/// "0x0000000000000000000000000000000000000000000000000000000000000065", +/// "0x000000000000000000000000000000000000000000000000000000000000006c", +/// "0x000000000000000000000000000000000000000000000000000000000000006c", +/// "0x000000000000000000000000000000000000000000000000000000000000006f", +/// ] +/// .into_iter() +/// .map(Felt::from_str) +/// .collect(); +/// +/// let array_span = ArraySpan::::try_from(data.unwrap()).unwrap(); +/// assert_eq!(array_span.data, vec![104, 101, 108, 108, 111]); +/// assert_eq!(String::from_utf8(array_span.data).unwrap(), "hello"); +/// ``` +/// +/// For more info: +/// https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays +#[derive(Debug)] +pub struct ArraySpan { + pub data: Vec, +} + +#[derive(Error, Debug)] +pub enum ArraySpanError { + #[error("Invalid array/span length")] + InvalidLength, + #[error("Failed to parse felt - {0}")] + ParsingFelt(String), +} + +impl TryFrom> for ArraySpan { + type Error = ArraySpanError; + + fn try_from(data: Vec) -> Result { + // First element is always the array length, which is a felt (so u8 is enough) + let arr_length = + u8::try_from(data[0]).map_err(|e| ArraySpanError::ParsingFelt(e.to_string()))?; + + // -1 because we have to offset the first element (the length itself) + let arr_length_usize = usize::from(arr_length); + if arr_length_usize != data.len().wrapping_sub(1) { + return Err(ArraySpanError::InvalidLength); + } + + let bytes: Result, ArraySpanError> = data + .get(1..) + .ok_or(ArraySpanError::InvalidLength)? + .iter() + .copied() + .map(|data| u8::try_from(data).map_err(|e| ArraySpanError::ParsingFelt(e.to_string()))) + .collect(); + + Ok(ArraySpan { data: bytes? }) + } +} + +#[cfg(test)] +mod array_span_tests { + use std::str::FromStr; + + use starknet_core::types::{Felt, FromStrError}; + + use super::ArraySpan; + + #[test] + fn try_from_valid_zeros() { + // the string "hello", but Felt is bigger than u8::max + let data = vec![Felt::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap()]; + + let array_span = ArraySpan::::try_from(data).unwrap(); + assert_eq!(array_span.data, Vec::::new()); + } + + #[test] + fn try_from_failed_to_parse_element_to_u8() { + // the string "hello", but Felt is bigger than u8::max + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000005", + "0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "0x0000000000000000000000000000000000000000000000000000000000000065", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006f", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let array_span = ArraySpan::::try_from(data.unwrap()); + assert!(array_span.is_err()); + } + + #[test] + fn try_from_failed_to_parse_elements_length_to_u32() { + // the string "hello", but element count is bigger than u32::max + let data: Result, FromStrError> = vec![ + "0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "0x0000000000000000000000000000000000000000000000000000000000000068", + "0x0000000000000000000000000000000000000000000000000000000000000065", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006f", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let array_span = ArraySpan::::try_from(data.unwrap()); + assert!(array_span.is_err()); + } + + #[test] + fn try_from_invalid_number_of_elements() { + // the string "hello", but with only 4 bytes + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000005", + "0x0000000000000000000000000000000000000000000000000000000000000068", + "0x0000000000000000000000000000000000000000000000000000000000000065", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006c", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let array_span = ArraySpan::::try_from(data.unwrap()); + assert!(array_span.is_err()); + } + + #[test] + fn try_from_invalid_declared_length() { + // the string "hello", with correct number of bytes, but only 4 declared, + // instead of 5 + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000004", + "0x0000000000000000000000000000000000000000000000000000000000000068", + "0x0000000000000000000000000000000000000000000000000000000000000065", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006f", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let array_span = ArraySpan::::try_from(data.unwrap()); + assert!(array_span.is_err()); + } + + #[test] + fn try_from_valid() { + // the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000005", + "0x0000000000000000000000000000000000000000000000000000000000000068", + "0x0000000000000000000000000000000000000000000000000000000000000065", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006c", + "0x000000000000000000000000000000000000000000000000000000000000006f", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let array_span = ArraySpan::::try_from(data.unwrap()).unwrap(); + assert_eq!(array_span.data, vec![104, 101, 108, 108, 111]); + } +} diff --git a/packages/starknet-types/src/types/byte_array.rs b/packages/starknet-types/src/types/byte_array.rs new file mode 100644 index 000000000..e9bc10dcf --- /dev/null +++ b/packages/starknet-types/src/types/byte_array.rs @@ -0,0 +1,444 @@ +use itertools::FoldWhile::{Continue, Done}; +use itertools::Itertools; +use starknet_core::types::Felt; +use starknet_core::utils::parse_cairo_short_string; +use thiserror::Error; + +/// Represents Cairo's ByteArray type. +/// Implements `TryFrom>`, which is the way to create it. +/// +/// ## Example usage with the string "hello" +/// +/// ```rust +/// use starknet_types::types::byte_array::ByteArray; +/// use std::str::FromStr; +/// use starknet_core::types::Felt; +/// use starknet_core::types::FromStrError; +/// +/// let data: Result, FromStrError> = vec![ +/// "0x0000000000000000000000000000000000000000000000000000000000000000", +/// "0x00000000000000000000000000000000000000000000000000000068656c6c6f", +/// "0x0000000000000000000000000000000000000000000000000000000000000005", +/// ] +/// .into_iter() +/// .map(Felt::from_str) +/// .collect(); +/// +/// let byte_array = ByteArray::try_from(data.unwrap()); +/// assert!(byte_array.is_ok()); +/// ``` +/// +/// For more info: +/// https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays +#[derive(Debug, Default)] +pub struct ByteArray { + /// The data byte array. Contains 31-byte chunks of the byte array. + data: Vec, + /// The bytes that remain after filling the data array with full 31-byte + /// chunks + pending_word: Felt, + /// The byte count of the pending_word + pending_word_length: u8, // can't be more than 30 bytes +} + +#[derive(Error, Debug)] +pub enum ByteArrayError { + #[error("Failed to fetch element from byte array at index")] + OutOfBound, + #[error("Invalid byte array - {0}")] + InvalidByteArray(String), + #[error("Failed to convert felt - {0}")] + ParsingFelt(String), + #[error("Failed to convert the byte array into a string")] + ToString, +} + +impl TryFrom> for ByteArray { + type Error = ByteArrayError; + + fn try_from(data: Vec) -> Result { + // pending word is always the next to last element + let pending_word_index = data.len().wrapping_sub(2); + let last_element_index = data.len().wrapping_sub(1); + + let mut byte_array = ByteArray { + ..Default::default() + }; + + if data.len() < 3 { + return Err(ByteArrayError::InvalidByteArray( + "vec should have minimum 3 elements".to_owned(), + )); + } + + // word count is always the first element, which is a felt (so u8 is enough) + let word_count = + u8::try_from(data[0]).map_err(|e| ByteArrayError::ParsingFelt(e.to_string()))?; + + // vec element count should be whatever the word count is + an offset of 3 + // the 3 stands for the minimum 3 elements: + // - word count + // - pending_word + // - pendint_word_length + let word_count_usize = usize::from(word_count.wrapping_add(3)); + if word_count_usize != data.len() { + return Err(ByteArrayError::InvalidByteArray( + "pre-defined count doesn't match actual 31byte element count".to_owned(), + )); + } + + // pending word byte count is always the last element + let pending_word_length_felt = data + .get(last_element_index) + .ok_or(ByteArrayError::OutOfBound)?; + let pending_word_length = u8::try_from(*pending_word_length_felt) + .map_err(|e| ByteArrayError::ParsingFelt(e.to_string()))?; + byte_array.pending_word_length = pending_word_length; + + let pending_word = data + .get(pending_word_index) + .ok_or(ByteArrayError::OutOfBound)?; + byte_array.pending_word = *pending_word; + + // count bytes, excluding leading zeros + let non_zero_pw_length = pending_word + .to_bytes_be() + .iter() + .fold_while(32, |acc: u8, n| { + if *n == 0 { + Continue(acc.saturating_sub(1)) + } else { + Done(acc) + } + }) + .into_inner(); + + if pending_word_length != non_zero_pw_length { + return Err(ByteArrayError::InvalidByteArray( + "pending_word length doesn't match it's defined length".to_owned(), + )); + } + + if word_count > 0 { + let byte_array_data = data + .get(1..pending_word_index) + .ok_or(ByteArrayError::OutOfBound)? + .to_vec(); + + byte_array.data = byte_array_data; + } + + Ok(byte_array) + } +} + +impl ByteArray { + /// Takes the ByteArray struct and tries to parse it as a single string + /// + /// ## Example usage with the string "hello" + /// + /// ```rust + /// use starknet_types::types::byte_array::ByteArray; + /// use std::str::FromStr; + /// use starknet_core::types::Felt; + /// use starknet_core::types::FromStrError; + /// + /// let data: Result, FromStrError> = vec![ + /// "0x0000000000000000000000000000000000000000000000000000000000000000", + /// "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + /// "0x0000000000000000000000000000000000000000000000000000000000000005", + /// ] + /// .into_iter() + /// .map(Felt::from_str) + /// .collect(); + /// + /// let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + /// assert_eq!("hello", byte_array.try_to_string().unwrap()); + /// ``` + /// + /// Additional documentation you can find here: + /// https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + pub fn try_to_string(&self) -> Result { + match self + .data + .iter() + .chain(std::iter::once(&self.pending_word)) + .map(parse_cairo_short_string) + .collect::>() + { + Ok(s) => Ok(s), + Err(_) => Err(ByteArrayError::ToString), + } + } +} + +#[cfg(test)] +mod byte_array_tests { + use std::str::FromStr; + + use starknet_core::types::{Felt, FromStrError}; + + use super::ByteArray; + + #[test] + fn byte_array_parse_fail_wrong_pending_word_length() { + // Example for a small string (fits in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000000", + "0x0000000000000000000000000000000000000000000000000000068656c6c6f", + // Should be of length 5 bytes, but we put 6 bytes, in order to fail + // the parsing + "0x0000000000000000000000000000000000000000000000000000000000000020", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()); + assert!(byte_array.is_err()); + } + + #[test] + fn byte_array_to_string_error() { + // Example for a small string (fits in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000000", + // Note the 01 in the beginning. This is what causes the parse + // function to error. + "0x01000000000000000000000000000000000000000000000000000068656c6c6f", + // 32(0x20) bytes long pending_word + "0x0000000000000000000000000000000000000000000000000000000000000020", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + assert!(byte_array.try_to_string().is_err()); + } + + #[test] + fn byte_array_single_pending_word_only_to_string_valid() { + // Example for a small string (fits in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000000", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x0000000000000000000000000000000000000000000000000000000000000005", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + assert_eq!("hello", byte_array.try_to_string().unwrap()); + } + + #[test] + fn byte_array_to_long_string_valid() { + // Example for a long string (doesn't fit in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "Long long string, a lot more than 31 characters that + // wouldn't even fit in two felts, so we'll have at least two felts and a + // pending word." + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000004", + "0x00004c6f6e67206c6f6e6720737472696e672c2061206c6f74206d6f72652074", + "0x000068616e2033312063686172616374657273207468617420776f756c646e27", + "0x000074206576656e2066697420696e2074776f2066656c74732c20736f207765", + "0x0000276c6c2068617665206174206c656173742074776f2066656c747320616e", + "0x0000000000000000000000000000006420612070656e64696e6720776f72642e", + "0x0000000000000000000000000000000000000000000000000000000000000011", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + assert_eq!("Long long string, a lot more than 31 characters that wouldn't even fit in two felts, so we'll have at least two felts and a pending word.", byte_array.try_to_string().unwrap()); + } + + #[test] + fn try_from_vec_count_less_then_3() { + let data: Result, FromStrError> = + vec!["0x0000000000000000000000000000000000000000000000000000000000000005"] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array_err = ByteArray::try_from(data.unwrap()); + assert!(byte_array_err.is_err()); + } + + #[test] + fn try_from_non_u32_word_count() { + let data: Result, FromStrError> = vec![ + // should be 0, because the message is short + // enough to fit in a single Felt + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x0000000000000000000000000000000000000000000000000000000000000005", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array_err = ByteArray::try_from(data.unwrap()); + assert!(byte_array_err.is_err()); + } + #[test] + fn try_from_invalid_byte_array_element_count() { + let data: Result, FromStrError> = vec![ + // should be 0, because the message is short + // enough to fit in a single Felt + "0x0000000000000000000000000000000000000000000000000000000000000005", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x0000000000000000000000000000000000000000000000000000000000000005", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array_err = ByteArray::try_from(data.unwrap()); + assert!(byte_array_err.is_err()); + } + + #[test] + fn try_from_non_u8_pending_word_length() { + // Example for a small string (fits in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000000", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()); + assert!(byte_array.is_err()); + } + + #[test] + fn try_from_valid_only_pending_word() { + // Example for a small string (fits in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "hello" + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000000", + "0x00000000000000000000000000000000000000000000000000000068656c6c6f", + "0x0000000000000000000000000000000000000000000000000000000000000005", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + + assert_eq!(byte_array.data, vec![]); + assert_eq!( + byte_array.pending_word, + Felt::from_str("0x00000000000000000000000000000000000000000000000000000068656c6c6f",) + .unwrap() + ); + assert_eq!(byte_array.pending_word_length, 5); + } + + #[test] + fn try_from_valid_one_big_string_split_in_multiple_data_elements() { + // Example for a long string (doesn't fit in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "Long long string, a lot more than 31 characters that + // wouldn't even fit in two felts, so we'll have at least two felts and a + // pending word." + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000004", + "0x00004c6f6e67206c6f6e6720737472696e672c2061206c6f74206d6f72652074", + "0x000068616e2033312063686172616374657273207468617420776f756c646e27", + "0x000074206576656e2066697420696e2074776f2066656c74732c20736f207765", + "0x0000276c6c2068617665206174206c656173742074776f2066656c747320616e", + "0x0000000000000000000000000000006420612070656e64696e6720776f72642e", + "0x0000000000000000000000000000000000000000000000000000000000000011", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + + assert_eq!( + byte_array.data, + vec![ + Felt::from_str( + "0x00004c6f6e67206c6f6e6720737472696e672c2061206c6f74206d6f72652074", + ) + .unwrap(), + Felt::from_str( + "0x000068616e2033312063686172616374657273207468617420776f756c646e27", + ) + .unwrap(), + Felt::from_str( + "0x000074206576656e2066697420696e2074776f2066656c74732c20736f207765", + ) + .unwrap(), + Felt::from_str( + "0x0000276c6c2068617665206174206c656173742074776f2066656c747320616e", + ) + .unwrap() + ] + ); + assert_eq!( + byte_array.pending_word, + Felt::from_str("0x0000000000000000000000000000006420612070656e64696e6720776f72642e",) + .unwrap() + ); + assert_eq!(byte_array.pending_word_length, 17); + } + + #[test] + fn try_from_valid_one_very_big_string() { + // Example for a long string (doesn't fit in a single felt) taken from here: + // https://docs.starknet.io/documentation/architecture_and_concepts/Smart_Contracts/serialization_of_Cairo_types/#serialization_of_byte_arrays + // + // So this is the string "Long string, more than 31 characters." + let data: Result, FromStrError> = vec![ + "0x0000000000000000000000000000000000000000000000000000000000000001", + "0x004c6f6e6720737472696e672c206d6f7265207468616e203331206368617261", + "0x000000000000000000000000000000000000000000000000000063746572732e", + "0x0000000000000000000000000000000000000000000000000000000000000006", + ] + .into_iter() + .map(Felt::from_str) + .collect(); + + let byte_array = ByteArray::try_from(data.unwrap()).unwrap(); + + assert_eq!( + byte_array.data, + vec![Felt::from_str( + "0x004c6f6e6720737472696e672c206d6f7265207468616e203331206368617261", + ) + .unwrap()] + ); + assert_eq!( + byte_array.pending_word, + Felt::from_str("0x000000000000000000000000000000000000000000000000000063746572732e",) + .unwrap() + ); + assert_eq!(byte_array.pending_word_length, 6); + } +} diff --git a/packages/starknet-types/src/types/starknet_message.rs b/packages/starknet-types/src/types/starknet_message.rs new file mode 100644 index 000000000..33ada5400 --- /dev/null +++ b/packages/starknet-types/src/types/starknet_message.rs @@ -0,0 +1,262 @@ +use std::str::FromStr; + +use error_stack::{Report, ResultExt}; +use ethers_core::abi::{ + AbiDecode, AbiError, AbiType, Detokenize, InvalidOutputType, ParamType, Token, Tokenizable, +}; +use ethers_core::types::U256; +use router_api::Message as RouterMessage; +use starknet_checked_felt::CheckedFelt; + +use crate::error::Error; + +/// A message that is encoded in the prover and later sent to the Starknet gateway. +#[derive(Clone, Debug, PartialEq)] +pub struct StarknetMessage { + pub source_chain: String, + pub message_id: String, + pub source_address: String, + pub contract_address: CheckedFelt, + pub payload_hash: U256, +} + +impl TryFrom<&RouterMessage> for StarknetMessage { + type Error = Report; + + fn try_from(msg: &RouterMessage) -> Result { + let contract_address = CheckedFelt::from_str(msg.destination_address.as_str()) + .change_context(Error::InvalidAddress)?; + + Ok(StarknetMessage { + source_chain: msg.cc_id.source_chain.to_string(), + message_id: msg.cc_id.message_id.to_string(), + source_address: msg.source_address.to_string(), + contract_address, + payload_hash: U256::from(msg.payload_hash), + }) + } +} + +impl AbiType for StarknetMessage { + fn param_type() -> ParamType { + ParamType::Tuple(vec![ + ethers_core::abi::ParamType::String, + ethers_core::abi::ParamType::String, + ethers_core::abi::ParamType::String, + ethers_core::abi::ParamType::FixedBytes(32usize), + ::param_type(), + ]) + } +} + +impl AbiDecode for StarknetMessage { + fn decode(bytes: impl AsRef<[u8]>) -> Result { + let tokens = ethers_core::abi::decode(&[Self::param_type()], bytes.as_ref())?; + Ok(::from_tokens(tokens)?) + } +} + +impl Tokenizable for StarknetMessage { + fn from_token(token: Token) -> Result + where + Self: Sized, + { + if let Token::Tuple(tokens) = token { + if tokens.len() != 5 { + return Err(InvalidOutputType( + "failed to read tokens: starknet message should have 5 tokens".to_string(), + )); + } + + if let ( + Token::String(source_chain), + Token::String(message_id), + Token::String(source_address), + Token::FixedBytes(contract_address), + Token::Uint(payload_hash), + ) = ( + tokens[0].clone(), + tokens[1].clone(), + tokens[2].clone(), + tokens[3].clone(), + tokens[4].clone(), + ) { + let contract_address_felt: CheckedFelt = + CheckedFelt::try_from(contract_address.as_slice()).map_err(|e| { + InvalidOutputType( + format!( + "failed to convert contract_address bytes to field element (felt): {}", + e + ) + .to_string(), + ) + })?; + + return Ok(StarknetMessage { + source_chain, + message_id, + source_address, + contract_address: contract_address_felt, + payload_hash, + }); + } + } + + Err(InvalidOutputType( + "failed to convert tokens to StarknetMessage".to_string(), + )) + } + + fn into_token(self) -> Token { + let contract_address_bytes = self.contract_address.to_bytes_be().to_vec(); + + Token::Tuple(vec![ + Token::String(self.source_chain), + Token::String(self.message_id), + Token::String(self.source_address), + Token::FixedBytes(contract_address_bytes), + Token::Uint(self.payload_hash), + ]) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use ethers_core::abi::{InvalidOutputType, Token, Tokenizable}; + use ethers_core::types::U256; + use starknet_checked_felt::CheckedFelt; + use starknet_core::types::Felt; + + use super::StarknetMessage; + + #[test] + fn starknet_message_from_token_should_error_on_non_tuple() { + // pas something else than a Token::Tuple + let starknet_msg_token = Token::String("not a starknet message".to_string()); + + let result = StarknetMessage::from_token(starknet_msg_token); + + // Tested like this, because InvalidOutputType doesn't implement PartialEq + assert!( + matches!(result, Err(InvalidOutputType(msg)) if msg == "failed to convert tokens to StarknetMessage") + ); + } + + #[test] + fn starknet_message_from_token_should_error_on_failing_felt_conversion() { + // overflow the 31 byte size of a Felt + let starknet_msg_token = Token::Tuple(vec![ + Token::String("starknet".to_string()), + Token::String("some_msg_id".to_string()), + Token::String("some_source_address".to_string()), + Token::FixedBytes(vec![ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + Token::Uint(U256::from(123)), + ]); + + let result = StarknetMessage::from_token(starknet_msg_token); + + // Tested like this, because InvalidOutputType doesn't implement PartialEq + assert!( + matches!(result, Err(InvalidOutputType(msg)) if msg == "failed to convert contract_address bytes to field element (felt): Felt value overflowing the Felt::MAX, value") + ); + } + + #[test] + fn starknet_message_from_token_should_error_on_failing_contract_address_conversion() { + // more than 32 bytes for contract address + let starknet_msg_token = Token::Tuple(vec![ + Token::String("starknet".to_string()), + Token::String("some_msg_id".to_string()), + Token::String("some_source_address".to_string()), + Token::FixedBytes(vec![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, 2, 1, + ]), + Token::Uint(U256::from(123)), + ]); + + let result = StarknetMessage::from_token(starknet_msg_token); + + // Tested like this, because InvalidOutputType doesn't implement PartialEq + assert!( + matches!(result, Err(InvalidOutputType(msg)) if msg == "failed to convert contract_address bytes to field element (felt): Felt value overflowing the Felt::MAX, value") + ); + } + + #[test] + fn starknet_message_from_token_should_error_on_less_tokens() { + // removed last token + let starknet_msg_token = Token::Tuple(vec![ + Token::String("starknet".to_string()), + Token::String("some_msg_id".to_string()), + Token::String("some_source_address".to_string()), + Token::FixedBytes(vec![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, + ]), + ]); + + let result = StarknetMessage::from_token(starknet_msg_token); + + // Tested like this, because InvalidOutputType doesn't implement PartialEq + assert!( + matches!(result, Err(InvalidOutputType(msg)) if msg == "failed to read tokens: starknet message should have 5 tokens") + ); + } + + #[test] + fn starknet_message_from_token_should_be_converted_from_tokens_successfully() { + let starknet_msg_token = Token::Tuple(vec![ + Token::String("starknet".to_string()), + Token::String("some_msg_id".to_string()), + Token::String("some_source_address".to_string()), + Token::FixedBytes(vec![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, + ]), + Token::Uint(U256::from(123)), + ]); + + let expected = StarknetMessage { + source_chain: "starknet".to_string(), + message_id: "some_msg_id".to_string(), + source_address: "some_source_address".to_string(), + contract_address: CheckedFelt::from_str(&Felt::THREE.to_fixed_hex_string()).unwrap(), + payload_hash: U256::from(123), + }; + + assert_eq!( + StarknetMessage::from_token(starknet_msg_token).unwrap(), + expected + ); + } + + #[test] + fn starknet_message_should_convert_to_token() { + let starknet_message = StarknetMessage { + source_chain: "starknet".to_string(), + message_id: "some_msg_id".to_string(), + source_address: "some_source_address".to_string(), + contract_address: CheckedFelt::from_str(&Felt::THREE.to_fixed_hex_string()).unwrap(), + payload_hash: U256::from(123), + }; + + let expected = Token::Tuple(vec![ + Token::String("starknet".to_string()), + Token::String("some_msg_id".to_string()), + Token::String("some_source_address".to_string()), + Token::FixedBytes(vec![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 3, + ]), + Token::Uint(U256::from(123)), + ]); + + assert_eq!(starknet_message.into_token(), expected); + } +}