Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions bin/adkg-cli/src/adkg_dxkr23.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async fn adkg_pairing_out_g2<'a, E, S, TBT>(
adkg_scheme: S,
topic_transport: Arc<TBT>,
writer: Option<InMemoryWriter>,
mut rng: impl AdkgRng + 'static,
rng: impl AdkgRng + 'static,
) -> anyhow::Result<()>
where
E: Pairing,
Expand Down Expand Up @@ -171,7 +171,7 @@ where
group_config,
topic_transport,
adkg_scheme,
&mut rng,
rng,
tx_adkg_out,
)
.await
Expand Down Expand Up @@ -315,7 +315,7 @@ async fn adkg_dxkr23<S, TBT>(
group_config: GroupConfig,
topic_transport: Arc<TBT>,
adkg_scheme: S,
rng: &mut impl AdkgRng,
rng: impl AdkgRng + 'static,
out: oneshot::Sender<AdkgOutput<S::Curve>>,
) -> anyhow::Result<()>
where
Expand All @@ -325,9 +325,9 @@ where
S::ABAConfig: AbaConfig<'static, PartyId, Input = AbaCrainInput<S::Curve>>,
<S::ACSSConfig as AcssConfig<'static, S::Curve, PartyId>>::Output:
Into<ShareWithPoly<S::Curve>>,
TBT: TopicBasedTransport<Identity = PartyId>,
TBT: TopicBasedTransport<Identity = PartyId> + Send + Sync + 'static,
{
let mut adkg = adkg_scheme.new_adkg(
let adkg = adkg_scheme.new_adkg(
adkg_config.id,
group_config.n,
group_config.t,
Expand All @@ -336,6 +336,10 @@ where
pks.clone(),
)?;

let (adkg_start_tx, adkg_start_rx) = oneshot::channel();
let (adkg_stop_tx, adkg_stop_rx) = oneshot::channel();
let adkg_out = adkg.run(adkg_start_rx, adkg_stop_rx, rng, topic_transport);

// Calculate time to sleep before actively executing the adkg
let sleep_duration = (group_config.start_time - chrono::Utc::now())
.to_std() // TimeDelta to positive duration
Expand All @@ -353,11 +357,14 @@ where
"Executing ADKG with a timeout of {}",
humantime::format_duration(adkg_config.timeout)
);
if adkg_start_tx.send(()).is_err() {
anyhow::bail!("Failed to send ADKG start signal");
}

let res = tokio::select! {
output = adkg.start(rng, topic_transport) => {
let output = match output {
Ok(adkg_out) => {
output = adkg_out => {
let output: anyhow::Result<_> = match output {
Some(Ok(adkg_out)) => {
tracing::info!(used_sessions = ?adkg_out.used_sessions, "Successfully obtained secret key & output from ADKG");
if out.send(adkg_out).is_err() {
// fails if the receiver side is dropped early
Expand All @@ -368,9 +375,13 @@ where
tokio::time::sleep(adkg_config.grace_period).await;
Ok(())
}
Err(e) => {
Some(Err(e)) => {
tracing::error!("failed to obtain output from ADKG: {e:?}");
Err(e)
Err(e.into())
}
None => {
tracing::error!("failed to obtain output from ADKG: stopped before an output");
Err(anyhow!("ADKG stopped before output"))
}
};

Expand All @@ -384,9 +395,9 @@ where
};

tracing::warn!("Stopping ADKG...");
adkg.stop().await;
let _ = adkg_stop_tx.send(());

Ok(res??)
res?
}

/// Pairing-based DLEQ proof that there exists an s_j s.t. P_1 = [s_j] G_1 \land P_2 = [s_j] G_2,
Expand Down
76 changes: 66 additions & 10 deletions crates/adkg/src/adkg.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use futures::{FutureExt, pin_mut};
mod randex;
pub(crate) mod types;

Expand Down Expand Up @@ -190,6 +191,7 @@ where
ACSSConfig::Output: Into<ShareWithPoly<CG>>,
ABAConfig: AbaConfig<'static, PartyId, Input = AbaCrainInput<CG>>,
{
/// Start the ADKG immediately
pub async fn start<T>(
&mut self,
rng: &mut impl AdkgRng,
Expand All @@ -198,7 +200,47 @@ where
where
T: TopicBasedTransport<Identity = PartyId>,
{
self.execute(rng, transport).await
self.execute_internal(std::future::ready(()), rng, transport)
.await
}

/// An alternative way to execute the adkg by managing the lifecycle asynchronously.
/// The function executes the ADKG once `start` is resolved, and stops `stop` is resolved.
///
/// The function returns immediately with a future that resolves upon obtaining an output.
pub fn run<T>(
mut self,
start: impl Future + Send + 'static,
stop: impl Future<Output: Send> + Send + 'static,
mut rng: impl AdkgRng + 'static,
transport: Arc<T>,
) -> impl Future<Output = Option<Result<AdkgOutput<CG>, AdkgError>>>
where
T: TopicBasedTransport<Identity = PartyId> + Send + Sync + 'static,
{
let (output_tx, output_rx) = tokio::sync::oneshot::channel();
tokio::spawn({
async move {
pin_mut!(stop);

tokio::select! {
out = self.execute_internal(start, &mut rng, transport) => {
// Send output
let _ = output_tx.send(out);

// Wait for the stop signal
stop.await;
},
_ = &mut stop => (),
}

// stop signal received, stop ADKG
info!("Stop signal received, stopping ADKG");
self.stop().await;
}
});

output_rx.map(Result::ok)
}

pub async fn stop(mut self) {
Expand Down Expand Up @@ -259,8 +301,9 @@ where
}
}

async fn execute<T>(
async fn execute_internal<T>(
&mut self,
start_signal: impl Future,
rng: &mut impl AdkgRng,
transport: Arc<T>,
) -> Result<AdkgOutput<CG>, AdkgError>
Expand Down Expand Up @@ -304,25 +347,29 @@ where
.collect();

// Start the multi RBC, ACSS and ABA
state
.multi_acss
.lock()
.await
.start(s, rng, transport.clone());
state.multi_acss.lock().await.start(rng, transport.clone());
state
.multi_rbc
.lock()
.await
.start(rbc_predicates, transport.clone());
state.multi_aba.lock().await.start(rng, transport.clone());

// Get the ACSS sender
let acss_leader_sender = state
.multi_acss
.lock()
.await
.get_leader_sender()
.expect("failed to get acss leader sender");

// Get the node's own RBC
let leader_sender = state
let rbc_leader_sender = state
.multi_rbc
.lock()
.await
.get_leader_sender()
.expect("failed to get leader sender");
.expect("failed to get rbc leader sender");

// Create cancellation tokens for each subtask
let acss_cancel = self.cancel.child_token();
Expand All @@ -331,7 +378,7 @@ where

// Handler for the key set proposal phase. Manages the termination of
self.acss_task = Some(task::spawn(Self::acss_task(
leader_sender,
rbc_leader_sender,
state.clone(),
acss_cancel.clone(),
)));
Expand All @@ -344,6 +391,15 @@ where
// Upon termination of jth ABA
let abas_task = task::spawn(Self::aba_outputs_task(state.clone(), aba_cancel.clone()));

// Everything has been set-up, wait for the start signal
start_signal.await;
if acss_leader_sender.send(s).is_err() {
error!(
"ADKG main thread of node `{}` failed to set ACSS input",
self.id
);
}

// Try to join ABAs task, and obtain the final list of parties.
info!(
"ADKG main thread of node `{}` waiting on ABA task to complete",
Expand Down
54 changes: 40 additions & 14 deletions crates/adkg/src/vss/acss/multi_acss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,21 @@ where
acss_config: Arc<ACSSConfig>,

// Attributes used to manage the subtasks
acss_tasks: JoinSet<(SessionId, Result<(), ACSSConfig::Error>)>, // set of acss tasks
acss_tasks: JoinSet<(SessionId, Result<(), MultiAcssError>)>, // set of acss tasks
acss_receivers: Vec<Option<oneshot::Receiver<ACSSConfig::Output>>>,
acss_leader_sender: Option<oneshot::Sender<ACSSConfig::Input>>, // set the leader input
cancels: Vec<CancellationToken>,
}

#[derive(thiserror::Error, Debug)]
pub enum MultiAcssError {
#[error(transparent)]
Acss(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

#[error("failed to get ACSS input from channel: sender dropped")]
AcssInputDropped,
}

impl<CG, ACSSConfig> MultiAcss<CG, ACSSConfig>
where
CG: CurveGroup,
Expand All @@ -42,12 +52,14 @@ where
acss_config,
acss_tasks: JoinSet::new(),
acss_receivers: vec![],
acss_leader_sender: None,
cancels,
}
}

/// Start the n parallel ACSS instances in the background.
pub fn start<T>(&mut self, s: ACSSConfig::Input, rng: &mut impl AdkgRng, transport: T)
/// Returns a channel used to transmit the ACSS secret.
pub fn start<T>(&mut self, rng: &mut impl AdkgRng, transport: T)
where
T: TopicBasedTransport<Identity = PartyId>,
{
Expand All @@ -57,7 +69,11 @@ where
.map(|(sender, receiver)| (sender, Some(receiver)))
.collect();
self.acss_receivers = receivers;
let mut s = Some(s); // need an option for interior mutability...

// Create one channel for the ACSS input
let (input_tx, input_rx) = oneshot::channel();
self.acss_leader_sender = Some(input_tx);
let mut input_rx = Some(input_rx); // need an option for interior mutability...

for (sid, cancel, sender) in izip!(
SessionId::iter_all(self.n_instances),
Expand All @@ -77,26 +93,31 @@ where
// s is not cloneable, and we only want to move it when sid == node_id
// In order to not move s due to the async move below, we take() s only once
// here, and use None when sid != node_id. This allows to move the value only once.
let s = if sid == node_id { s.take() } else { None };
let mut input_rx = if sid == node_id {
input_rx.take()
} else {
None
};

let mut rng = rng
.get(AdkgRngType::Acss(sid))
.expect("failed to obtain acss rng");
async move {
// Start the acss tasks
let res = if sid == node_id {
acss.deal(
s.expect("can only enter once"), // s must be Some(.) since sid == node_id
cancellation_token,
sender,
&mut rng,
)
.instrument(tracing::warn_span!("ACSS::deal", ?sid))
.await
if let Ok(s) = input_rx.take().expect("to enter once").await {
acss.deal(s, cancellation_token, sender, &mut rng)
.instrument(tracing::warn_span!("ACSS::deal", ?sid))
.await
.map_err(|e| MultiAcssError::Acss(e.into()))
} else {
Err(MultiAcssError::AcssInputDropped)
}
} else {
acss.get_share(sid.into(), cancellation_token, sender, &mut rng)
.instrument(tracing::warn_span!("ACSS::get_share", ?sid))
.await
.map_err(|e| MultiAcssError::Acss(e.into()))
};

(sid, res)
Expand All @@ -105,6 +126,11 @@ where
}
}

/// Get the oneshot sender used to set the leader output of the ACSS where self.node_id == sid
pub fn get_leader_sender(&mut self) -> Option<oneshot::Sender<ACSSConfig::Input>> {
self.acss_leader_sender.take()
}

/// Create an iterator over the remaining ACSS outputs.
pub fn iter_remaining_outputs(
&mut self,
Expand All @@ -124,11 +150,11 @@ where
}

/// Stop the ACSS instances and return Ok(()) if no errors were output, otherwise, return the identifier of failed instances and their errors.
pub async fn stop(self) -> Result<(), Vec<(SessionId, ACSSConfig::Error)>> {
pub async fn stop(self) -> Result<(), Vec<(SessionId, MultiAcssError)>> {
// Signal cancellation through each of the cancellation tokens
self.cancels.iter().for_each(|cancel| cancel.cancel());

let errors: Vec<(SessionId, ACSSConfig::Error)> = self
let errors: Vec<(SessionId, MultiAcssError)> = self
.acss_tasks
.join_all()
.await
Expand Down
Loading