Skip to content

Commit

Permalink
(WIP) Integrate RequestTracker into the network logic
Browse files Browse the repository at this point in the history
  • Loading branch information
madadam committed Sep 26, 2024
1 parent eeb75ce commit 355aacc
Show file tree
Hide file tree
Showing 11 changed files with 342 additions and 245 deletions.
3 changes: 3 additions & 0 deletions lib/src/block_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ pub(crate) struct TrackerClient {

impl TrackerClient {
/// Returns a stream of offers for required blocks.
#[cfg_attr(not(test), expect(dead_code))]
pub fn offers(&self) -> BlockOffers {
BlockOffers {
shared: self.shared.clone(),
Expand All @@ -148,6 +149,7 @@ impl TrackerClient {
/// Registers an offer for a block with the given id.
/// Returns `true` if this block was offered for the first time (by any client) or `false` if
/// it's already been offered but not yet accepted or cancelled.
#[cfg_attr(not(test), expect(dead_code))]
pub fn register(&self, block_id: BlockId, state: OfferState) -> bool {
let mut inner = self.shared.inner.lock().unwrap();

Expand Down Expand Up @@ -235,6 +237,7 @@ pub(crate) struct BlockOffers {

impl BlockOffers {
/// Returns the next offer, waiting for one to appear if necessary.
#[cfg_attr(not(test), expect(dead_code))]
pub async fn next(&mut self) -> BlockOffer {
loop {
if let Some(offer) = self.try_next() {
Expand Down
139 changes: 100 additions & 39 deletions lib/src/network/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use super::{
constants::RESPONSE_BATCH_SIZE,
debug_payload::{DebugRequest, DebugResponse},
message::{Message, Response, ResponseDisambiguator},
pending::{
EphemeralResponse, PendingRequest, PendingRequests, PersistableResponse, PreparedResponse,
},
pending::{EphemeralResponse, PendingRequests, PersistableResponse, PreparedResponse},
request_tracker::{PendingRequest, RequestTracker, RequestTrackerClient},
};
use crate::{
block_tracker::{BlockPromise, TrackerClient},
crypto::{sign::PublicKey, CacheHash, Hashable},
error::Result,
event::Payload,
network::{message::Request, request_tracker::MessageKey},
protocol::{
Block, BlockId, InnerNodes, LeafNodes, MultiBlockPresence, ProofError, RootNodeFilter,
UntrustedProof,
Expand All @@ -29,6 +29,7 @@ mod future {

pub(super) struct Client {
inner: Inner,
request_rx: mpsc::UnboundedReceiver<PendingRequest>,
response_rx: mpsc::Receiver<Response>,
}

Expand All @@ -37,50 +38,66 @@ impl Client {
vault: Vault,
message_tx: mpsc::UnboundedSender<Message>,
response_rx: mpsc::Receiver<Response>,
request_tracker: &RequestTracker,
) -> Self {
let (request_tracker, request_rx) = request_tracker.new_client();
let pending_requests = PendingRequests::new(vault.monitor.clone());
let block_tracker = vault.block_tracker.client();

let inner = Inner {
vault,
request_tracker,
pending_requests,
block_tracker,
message_tx,
};

Self { inner, response_rx }
Self {
inner,
request_rx,
response_rx,
}
}
}

impl Client {
pub async fn run(&mut self) -> Result<()> {
let Self { inner, response_rx } = self;
let Self {
inner,
request_rx,
response_rx,
} = self;

inner.run(response_rx).await
inner.run(request_rx, response_rx).await
}
}

struct Inner {
vault: Vault,
request_tracker: RequestTrackerClient,
pending_requests: PendingRequests,
#[expect(dead_code)]
block_tracker: TrackerClient,
message_tx: mpsc::UnboundedSender<Message>,
}

impl Inner {
async fn run(&mut self, response_rx: &mut mpsc::Receiver<Response>) -> Result<()> {
async fn run(
&mut self,
request_rx: &mut mpsc::UnboundedReceiver<PendingRequest>,
response_rx: &mut mpsc::Receiver<Response>,
) -> Result<()> {
select! {
result = self.handle_responses(response_rx) => result,
_ = self.send_requests(request_rx) => Ok(()),
_ = self.handle_available_block_offers() => Ok(()),
_ = self.handle_reload_index() => Ok(()),
}
}

fn send_request(&self, request: PendingRequest) {
if let Some(request) = self.pending_requests.insert(request) {
self.message_tx
.send(Message::Request(request))
.unwrap_or(());
async fn send_requests(&self, request_rx: &mut mpsc::UnboundedReceiver<PendingRequest>) {
while let Some(PendingRequest { request, .. }) = request_rx.recv().await {
self.message_tx.send(Message::Request(request)).ok();
}
}

Expand Down Expand Up @@ -221,16 +238,23 @@ impl Inner {
}

let hash = proof.hash;
let writer_id = proof.writer_id;
let status = writer.save_root_node(proof, &block_presence).await?;

tracing::debug!("Received root node - {status}");

if status.request_children() {
self.send_request(PendingRequest::ChildNodes(
hash,
ResponseDisambiguator::new(block_presence),
debug_payload.follow_up(),
));
self.request_tracker.success(
MessageKey::RootNode(writer_id),
vec![PendingRequest {
request: Request::ChildNodes(
hash,
ResponseDisambiguator::new(block_presence),
debug_payload.follow_up(),
),
block_presence,
}],
);
}

Ok(())
Expand All @@ -243,6 +267,7 @@ impl Inner {
nodes: CacheHash<InnerNodes>,
debug_payload: DebugResponse,
) -> Result<()> {
let hash = nodes.hash();
let total = nodes.len();
let status = writer.save_inner_nodes(nodes).await?;

Expand All @@ -252,13 +277,21 @@ impl Inner {
total
);

for node in status.new_children {
self.send_request(PendingRequest::ChildNodes(
node.hash,
ResponseDisambiguator::new(node.summary.block_presence),
debug_payload.follow_up(),
));
}
self.request_tracker.success(
MessageKey::ChildNodes(hash),
status
.new_children
.into_iter()
.map(|node| PendingRequest {
request: Request::ChildNodes(
node.hash,
ResponseDisambiguator::new(node.summary.block_presence),
debug_payload.follow_up(),
),
block_presence: node.summary.block_presence,
})
.collect(),
);

Ok(())
}
Expand All @@ -270,6 +303,7 @@ impl Inner {
nodes: CacheHash<LeafNodes>,
debug_payload: DebugResponse,
) -> Result<()> {
let hash = nodes.hash();
let total = nodes.len();
let status = writer.save_leaf_nodes(nodes).await?;

Expand All @@ -279,27 +313,45 @@ impl Inner {
total,
);

for (block_id, state) in status.new_block_offers {
self.block_tracker.register(block_id, state);
}
// TODO:
// for (block_id, state) in status.new_block_offers {
// self.block_tracker.register(block_id, state);
// }

self.request_tracker.success(
MessageKey::ChildNodes(hash),
status
.new_block_offers
.into_iter()
.map(|(block_id, _)| PendingRequest {
request: Request::Block(block_id, debug_payload.follow_up()),
block_presence: MultiBlockPresence::None,
})
.collect(),
);

Ok(())
}

#[instrument(skip_all, fields(id = ?block_id, ?debug_payload), err(Debug))]
async fn handle_block_offer(
&self,
reader: &mut ClientReader,
_reader: &mut ClientReader,
block_id: BlockId,
debug_payload: DebugResponse,
) -> Result<()> {
let Some(offer_state) = reader.load_block_offer_state(&block_id).await? else {
return Ok(());
};
// TODO:

// let Some(offer_state) = reader.load_block_offer_state(&block_id).await? else {
// return Ok(());
// };

// tracing::trace!(?offer_state, "Received block offer");

tracing::trace!(?offer_state, "Received block offer");
// self.block_tracker.register(block_id, offer_state);

self.block_tracker.register(block_id, offer_state);
self.request_tracker
.initial(Request::Block(block_id, debug_payload.follow_up()));

Ok(())
}
Expand All @@ -316,6 +368,9 @@ impl Inner {

tracing::trace!("Received block");

self.request_tracker
.success(MessageKey::Block(block.id), vec![]);

Ok(())
}

Expand Down Expand Up @@ -354,12 +409,14 @@ impl Inner {
}

async fn handle_available_block_offers(&self) {
let mut block_offers = self.block_tracker.offers();
// TODO:

loop {
let block_offer = block_offers.next().await;
self.send_request(PendingRequest::Block(block_offer, DebugRequest::start()));
}
// let mut block_offers = self.block_tracker.offers();

// loop {
// let block_offer = block_offers.next().await;
// self.send_request(PendingRequest::Block(block_offer, DebugRequest::start()));
// }
}

async fn handle_reload_index(&self) {
Expand Down Expand Up @@ -391,7 +448,8 @@ impl Inner {
// requested as soon as possible.
fn refresh_branches(&self, branches: impl IntoIterator<Item = PublicKey>) {
for branch_id in branches {
self.send_request(PendingRequest::RootNode(branch_id, DebugRequest::start()));
self.request_tracker
.initial(Request::RootNode(branch_id, DebugRequest::start()));
}
}

Expand Down Expand Up @@ -547,13 +605,16 @@ mod tests {

vault.block_tracker.set_request_mode(RequestMode::Lazy);

let request_tracker = RequestTracker::new();
let (request_tracker, _request_rx) = request_tracker.new_client();
let pending_requests = PendingRequests::new(vault.monitor.clone());
let block_tracker = vault.block_tracker.client();

let (message_tx, _message_rx) = mpsc::unbounded_channel();

let inner = Inner {
vault,
request_tracker,
pending_requests,
block_tracker,
message_tx,
Expand Down
Loading

0 comments on commit 355aacc

Please sign in to comment.