Skip to content

Commit

Permalink
refactor(target_chains/starknet): remove Result from reader (#1536)
Browse files Browse the repository at this point in the history
  • Loading branch information
Riateche authored May 2, 2024
1 parent 587a6fa commit 3085997
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 145 deletions.
13 changes: 3 additions & 10 deletions target_chains/starknet/contracts/src/hash.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,9 @@ pub impl HasherImpl of HasherTrait {
// reader.len() < 8
chunk_len = reader.len().try_into().expect(UNEXPECTED_OVERFLOW);
}
match reader.read_num_bytes(chunk_len) {
Result::Ok(value) => {
// chunk_len <= 8 so value must fit in u64.
self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len);
},
Result::Err(err) => {
result = Result::Err(err);
break;
},
}
let value = reader.read_num_bytes(chunk_len);
// chunk_len <= 8 so value must fit in u64.
self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len);
};
result
}
Expand Down
13 changes: 3 additions & 10 deletions target_chains/starknet/contracts/src/merkle_tree.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,13 @@ pub fn read_and_verify_proof(
let mut message_reader = ReaderImpl::new(message.clone());
let mut current_hash = leaf_hash(message_reader.clone()).map_err()?;

let proof_size = reader.read_u8().map_err()?;
let proof_size = reader.read_u8();
let mut i = 0;

let mut result = Result::Ok(());
while i < proof_size {
match reader.read_u160().map_err() {
Result::Ok(sibling_digest) => {
current_hash = node_hash(current_hash, sibling_digest);
},
Result::Err(err) => {
result = Result::Err(err);
break;
},
}
let sibling_digest = reader.read_u160();
current_hash = node_hash(current_hash, sibling_digest);
i += 1;
};
result?;
Expand Down
50 changes: 25 additions & 25 deletions target_chains/starknet/contracts/src/pyth.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -357,30 +357,30 @@ mod pyth {
ref self: ContractState, data: ByteArray
) -> Result<(), UpdatePriceFeedsError> {
let mut reader = ReaderImpl::new(data);
let x = reader.read_u32().map_err()?;
let x = reader.read_u32();
if x != ACCUMULATOR_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
}
if reader.read_u8().map_err()? != MAJOR_VERSION {
if reader.read_u8() != MAJOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
}
if reader.read_u8().map_err()? < MINIMUM_ALLOWED_MINOR_VERSION {
if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
}

let trailing_header_size = reader.read_u8().map_err()?;
reader.skip(trailing_header_size).map_err()?;
let trailing_header_size = reader.read_u8();
reader.skip(trailing_header_size);

let update_type: Option<UpdateType> = reader.read_u8().map_err()?.try_into();
let update_type: Option<UpdateType> = reader.read_u8().try_into();
match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};

let wh_proof_size = reader.read_u16().map_err()?;
let wh_proof = reader.read_byte_array(wh_proof_size.into()).map_err()?;
let wh_proof_size = reader.read_u16();
let wh_proof = reader.read_byte_array(wh_proof_size.into());
let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
let vm = wormhole.parse_and_verify_vm(wh_proof).map_err()?;

Expand All @@ -392,24 +392,24 @@ mod pyth {
}

let mut payload_reader = ReaderImpl::new(vm.payload);
let x = payload_reader.read_u32().map_err()?;
let x = payload_reader.read_u32();
if x != ACCUMULATOR_WORMHOLE_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
}

let update_type: Option<UpdateType> = payload_reader.read_u8().map_err()?.try_into();
let update_type: Option<UpdateType> = payload_reader.read_u8().try_into();
match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};

let _slot = payload_reader.read_u64().map_err()?;
let _ring_size = payload_reader.read_u32().map_err()?;
let root_digest = payload_reader.read_u160().map_err()?;
let _slot = payload_reader.read_u64();
let _ring_size = payload_reader.read_u32();
let root_digest = payload_reader.read_u160();

let num_updates = reader.read_u8().map_err()?;
let num_updates = reader.read_u8();

let total_fee = get_total_fee(ref self, num_updates);
let fee_contract = IERC20CamelDispatcher {
Expand Down Expand Up @@ -451,27 +451,27 @@ mod pyth {
fn read_and_verify_message(
ref reader: Reader, root_digest: u256
) -> Result<PriceFeedMessage, UpdatePriceFeedsError> {
let message_size = reader.read_u16().map_err()?;
let message = reader.read_byte_array(message_size.into()).map_err()?;
let message_size = reader.read_u16();
let message = reader.read_byte_array(message_size.into());
read_and_verify_proof(root_digest, @message, ref reader).map_err()?;

let mut message_reader = ReaderImpl::new(message);
let message_type: Option<MessageType> = message_reader.read_u8().map_err()?.try_into();
let message_type: Option<MessageType> = message_reader.read_u8().try_into();
match message_type {
Option::Some(v) => match v {
MessageType::PriceFeed => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};

let price_id = message_reader.read_u256().map_err()?;
let price = u64_as_i64(message_reader.read_u64().map_err()?);
let conf = message_reader.read_u64().map_err()?;
let expo = u32_as_i32(message_reader.read_u32().map_err()?);
let publish_time = message_reader.read_u64().map_err()?;
let prev_publish_time = message_reader.read_u64().map_err()?;
let ema_price = u64_as_i64(message_reader.read_u64().map_err()?);
let ema_conf = message_reader.read_u64().map_err()?;
let price_id = message_reader.read_u256();
let price = u64_as_i64(message_reader.read_u64());
let conf = message_reader.read_u64();
let expo = u32_as_i32(message_reader.read_u32());
let publish_time = message_reader.read_u64();
let prev_publish_time = message_reader.read_u64();
let ema_price = u64_as_i64(message_reader.read_u64());
let ema_conf = message_reader.read_u64();

let message = PriceFeedMessage {
price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf,
Expand Down
132 changes: 49 additions & 83 deletions target_chains/starknet/contracts/src/reader.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -42,112 +42,77 @@ pub impl ReaderImpl of ReaderTrait {
}

/// Reads the specified number of bytes (up to 16) as a big endian unsigned integer.
fn read_num_bytes(ref self: Reader, num_bytes: u8) -> Result<u128, Error> {
fn read_num_bytes(ref self: Reader, num_bytes: u8) -> u128 {
assert!(num_bytes <= 16, "Reader::read_num_bytes: num_bytes is too large");
if num_bytes <= self.num_current_bytes {
let x = self.read_from_current(num_bytes);
return Result::Ok(x);
return self.read_from_current(num_bytes);
}
let num_low_bytes = num_bytes - self.num_current_bytes;
let high = self.current;
self.fetch_next()?;
let low = self.read_num_bytes(num_low_bytes)?;
let value = if num_low_bytes == 16 {
self.fetch_next();
let low = self.read_num_bytes(num_low_bytes);
if num_low_bytes == 16 {
low
} else {
high * one_shift_left_bytes_u128(num_low_bytes) + low
};
Result::Ok(value)
}
}

fn read_u256(ref self: Reader) -> Result<u256, Error> {
let high = self.read_num_bytes(16)?;
let low = self.read_num_bytes(16)?;
fn read_u256(ref self: Reader) -> u256 {
let high = self.read_num_bytes(16);
let low = self.read_num_bytes(16);
let value = u256 { high, low };
Result::Ok(value)
value
}
fn read_u160(ref self: Reader) -> Result<u256, Error> {
let high = self.read_num_bytes(4)?;
let low = self.read_num_bytes(16)?;
let value = u256 { high, low };
Result::Ok(value)
fn read_u160(ref self: Reader) -> u256 {
let high = self.read_num_bytes(4);
let low = self.read_num_bytes(16);
u256 { high, low }
}
fn read_u128(ref self: Reader) -> Result<u128, Error> {
fn read_u128(ref self: Reader) -> u128 {
self.read_num_bytes(16)
}
fn read_u64(ref self: Reader) -> Result<u64, Error> {
let value = self.read_num_bytes(8)?.try_into().expect(UNEXPECTED_OVERFLOW);
Result::Ok(value)
fn read_u64(ref self: Reader) -> u64 {
self.read_num_bytes(8).try_into().expect(UNEXPECTED_OVERFLOW)
}
fn read_u32(ref self: Reader) -> Result<u32, Error> {
let value = self.read_num_bytes(4)?.try_into().expect(UNEXPECTED_OVERFLOW);
Result::Ok(value)
fn read_u32(ref self: Reader) -> u32 {
self.read_num_bytes(4).try_into().expect(UNEXPECTED_OVERFLOW)
}
fn read_u16(ref self: Reader) -> Result<u16, Error> {
let value = self.read_num_bytes(2)?.try_into().expect(UNEXPECTED_OVERFLOW);
Result::Ok(value)
fn read_u16(ref self: Reader) -> u16 {
self.read_num_bytes(2).try_into().expect(UNEXPECTED_OVERFLOW)
}
fn read_u8(ref self: Reader) -> Result<u8, Error> {
let value = self.read_num_bytes(1)?.try_into().expect(UNEXPECTED_OVERFLOW);
Result::Ok(value)
fn read_u8(ref self: Reader) -> u8 {
self.read_num_bytes(1).try_into().expect(UNEXPECTED_OVERFLOW)
}

// TODO: skip without calculating values
fn skip(ref self: Reader, mut num_bytes: u8) -> Result<(), Error> {
let mut result = Result::Ok(());
fn skip(ref self: Reader, mut num_bytes: u8) {
while num_bytes > 0 {
if num_bytes > 16 {
match self.read_num_bytes(16) {
Result::Ok(_) => {},
Result::Err(err) => {
result = Result::Err(err);
break;
}
}
self.read_num_bytes(16);
num_bytes -= 16;
} else {
match self.read_num_bytes(num_bytes) {
Result::Ok(_) => {},
Result::Err(err) => {
result = Result::Err(err);
break;
}
}
break;
self.read_num_bytes(num_bytes);
}
};
result
}
}

/// Reads the specified number of bytes as a new byte array.
fn read_byte_array(ref self: Reader, num_bytes: usize) -> Result<ByteArray, Error> {
fn read_byte_array(ref self: Reader, num_bytes: usize) -> ByteArray {
let mut array: Array<bytes31> = array![];
let mut num_last_bytes = Option::None;
let mut num_last_bytes = 0;
let mut num_remaining_bytes = num_bytes;
loop {
let r = self.read_bytes_iteration(num_remaining_bytes, ref array);
match r {
Result::Ok((
num_read, eof
)) => {
num_remaining_bytes -= num_read;
if eof {
num_last_bytes = Option::Some(Result::Ok(num_read));
break;
}
},
Result::Err(err) => {
num_last_bytes = Option::Some(Result::Err(err));
break;
}
let (num_read, eof) = self.read_bytes_iteration(num_remaining_bytes, ref array);
num_remaining_bytes -= num_read;
if eof {
num_last_bytes = num_read;
break;
}
};
// `num_last_bytes` is always set to Some before break.
let num_last_bytes = num_last_bytes.unwrap()?;
// num_last_bytes < 31
let num_last_bytes = num_last_bytes.try_into().expect(UNEXPECTED_OVERFLOW);
let array = ByteArrayImpl::new(array, num_last_bytes);
Result::Ok(array)
ByteArrayImpl::new(array, num_last_bytes)
}

/// Returns number of remaining bytes to read.
Expand Down Expand Up @@ -179,15 +144,18 @@ impl ReaderPrivateImpl of ReaderPrivateTrait {
/// Replenishes `self.current` and `self.num_current_bytes`.
/// This should only be called when all bytes from `self.current` has been read.
/// Returns `EOF` error if no more data is available.
fn fetch_next(ref self: Reader) -> Result<(), Error> {
fn fetch_next(ref self: Reader) {
match self.next {
Option::Some(next) => {
self.next = Option::None;
self.current = next;
self.num_current_bytes = 16;
},
Option::None => {
let (value, bytes) = self.array.pop_front().ok_or(Error::UnexpectedEndOfInput)?;
let (value, bytes) = self
.array
.pop_front()
.expect(Error::UnexpectedEndOfInput.into());
let value: u256 = value.into();
if bytes > 16 {
self.current = value.high;
Expand All @@ -199,33 +167,31 @@ impl ReaderPrivateImpl of ReaderPrivateTrait {
}
},
}
Result::Ok(())
}

// Moved out from `read_bytes` because we cannot use `return` or `?` within a loop.
fn read_bytes_iteration(
ref self: Reader, num_bytes: usize, ref array: Array<bytes31>
) -> Result<(usize, bool), Error> {
) -> (usize, bool) {
if num_bytes >= 31 {
let high = self.read_num_bytes(15)?;
let low = self.read_num_bytes(16)?;
let high = self.read_num_bytes(15);
let low = self.read_num_bytes(16);
let value: felt252 = u256 { high, low }.try_into().expect(UNEXPECTED_OVERFLOW);
array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
Result::Ok((31, false))
(31, false)
} else if num_bytes > 16 {
// num_bytes < 31
let high = self
.read_num_bytes((num_bytes - 16).try_into().expect(UNEXPECTED_OVERFLOW))?;
let low = self.read_num_bytes(16)?;
let high = self.read_num_bytes((num_bytes - 16).try_into().expect(UNEXPECTED_OVERFLOW));
let low = self.read_num_bytes(16);
let value: felt252 = u256 { high, low }.try_into().expect(UNEXPECTED_OVERFLOW);
array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
Result::Ok((num_bytes, true))
(num_bytes, true)
} else {
// bytes < 16
let low = self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW))?;
let low = self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW));
let value: felt252 = low.try_into().expect(UNEXPECTED_OVERFLOW);
array.append(value.try_into().expect(UNEXPECTED_OVERFLOW));
Result::Ok((num_bytes, true))
(num_bytes, true)
}
}
}
Loading

0 comments on commit 3085997

Please sign in to comment.