Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/wasmtime/src/runtime/component/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4311,7 +4311,7 @@ struct LiftResult {
/// This exists to minimize table lookups and the necessity to pass stores around mutably
/// for the common case of identifying the task to which a thread belongs.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct QualifiedThreadId {
pub(crate) struct QualifiedThreadId {
task: TableId<GuestTask>,
thread: TableId<GuestThread>,
}
Expand Down Expand Up @@ -4806,7 +4806,7 @@ impl ConcurrentInstanceState {
}

#[derive(Debug, Copy, Clone)]
enum CurrentThread {
pub(crate) enum CurrentThread {
Guest(QualifiedThreadId),
Host(TableId<HostTask>),
None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::table::{TableDebug, TableId};
use super::{Event, GlobalErrorContextRefCount, Waitable, WaitableCommon};
use crate::component::concurrent::{ConcurrentState, WorkItem, tls};
use crate::component::concurrent::{ConcurrentState, QualifiedThreadId, WorkItem, tls};
use crate::component::func::{self, LiftContext, LowerContext};
use crate::component::matching::InstanceType;
use crate::component::types;
Expand Down Expand Up @@ -143,6 +143,7 @@ fn get_mut_by_index_from(
fn lower<T: func::Lower + Send + 'static, B: WriteBuffer<T>, U: 'static>(
mut store: StoreContextMut<U>,
instance: Instance,
caller_thread: QualifiedThreadId,
options: OptionsIndex,
ty: TransmitIndex,
address: usize,
Expand All @@ -151,11 +152,21 @@ fn lower<T: func::Lower + Send + 'static, B: WriteBuffer<T>, U: 'static>(
) -> Result<()> {
let count = buffer.remaining().len().min(count);

let lower = &mut if T::MAY_REQUIRE_REALLOC {
LowerContext::new
// If lowering may call realloc in the guest, then the guest may need
// to access its thread context, so we need to set the current thread before lowering
// and restore the old one afterward.
let (lower, old_thread) = if T::MAY_REQUIRE_REALLOC {
let old_thread = store.0.set_thread(caller_thread)?;
(
&mut LowerContext::new(store.as_context_mut(), options, instance),
Some(old_thread),
)
} else {
LowerContext::new_without_realloc
}(store.as_context_mut(), options, instance);
(
&mut LowerContext::new_without_realloc(store.as_context_mut(), options, instance),
None,
)
};

if address % usize::try_from(T::ALIGN32)? != 0 {
bail!("read pointer not aligned");
Expand All @@ -170,6 +181,10 @@ fn lower<T: func::Lower + Send + 'static, B: WriteBuffer<T>, U: 'static>(
T::linear_store_list_to_memory(lower, ty, address, &buffer.remaining()[..count])?;
}

if let Some(old_thread) = old_thread {
store.0.set_thread(old_thread)?;
}

buffer.skip(count);

Ok(())
Expand Down Expand Up @@ -2196,7 +2211,8 @@ enum ReadState {
/// The read end is owned by a guest task and a read is pending.
GuestReady {
ty: TransmitIndex,
caller: RuntimeComponentInstanceIndex,
caller_instance: RuntimeComponentInstanceIndex,
caller_thread: QualifiedThreadId,
flat_abi: Option<FlatAbi>,
instance: Instance,
options: OptionsIndex,
Expand Down Expand Up @@ -2930,7 +2946,8 @@ async fn write<D: 'static, P: Send + 'static, T: func::Lower + 'static, B: Write
count,
handle,
instance,
caller,
caller_instance,
caller_thread,
} => {
let guest_offset = match guest_offset {
Some(i) => i,
Expand All @@ -2952,6 +2969,7 @@ async fn write<D: 'static, P: Send + 'static, T: func::Lower + 'static, B: Write
lower::<T, B, D>(
store.as_context_mut(),
instance,
caller_thread,
options,
ty,
address + (T::SIZE32 * guest_offset),
Expand Down Expand Up @@ -3008,7 +3026,8 @@ async fn write<D: 'static, P: Send + 'static, T: func::Lower + 'static, B: Write
count,
handle,
instance,
caller,
caller_instance,
caller_thread,
};

crate::error::Ok(())
Expand Down Expand Up @@ -3178,11 +3197,12 @@ impl Instance {
self,
mut store: StoreContextMut<T>,
flat_abi: Option<FlatAbi>,
write_caller: RuntimeComponentInstanceIndex,
write_caller_instance: RuntimeComponentInstanceIndex,
write_ty: TransmitIndex,
write_options: OptionsIndex,
write_address: usize,
read_caller: RuntimeComponentInstanceIndex,
read_caller_instance: RuntimeComponentInstanceIndex,
read_caller_thread: QualifiedThreadId,
read_ty: TransmitIndex,
read_options: OptionsIndex,
read_address: usize,
Expand All @@ -3198,7 +3218,9 @@ impl Instance {

let payload = types[types[write_ty].ty].payload;

if write_caller == read_caller && !allow_intra_component_read_write(payload) {
if write_caller_instance == read_caller_instance
&& !allow_intra_component_read_write(payload)
{
bail!(
"cannot read from and write to intra-component future with non-numeric payload"
)
Expand Down Expand Up @@ -3228,6 +3250,10 @@ impl Instance {
.transpose()?;

if let Some(val) = val {
// Serializing the value may require calling the guest's realloc function, so we
// set the guest's thread context in case realloc requires it, and restore the original
// thread context after the copy is complete.
let old_thread = store.0.set_thread(read_caller_thread)?;
let lower = &mut LowerContext::new(store.as_context_mut(), read_options, self);
let types = lower.types;
let ty = match types[types[read_ty].ty].payload {
Expand All @@ -3240,10 +3266,11 @@ impl Instance {
&ValRaw::u32(read_address.try_into()?),
)?;
val.store(lower, ty, ptr)?;
store.0.set_thread(old_thread)?;
}
}
(TransmitIndex::Stream(write_ty), TransmitIndex::Stream(read_ty)) => {
if write_caller == read_caller
if write_caller_instance == read_caller_instance
&& !allow_intra_component_read_write(types[types[write_ty].ty].payload)
{
bail!(
Expand Down Expand Up @@ -3284,7 +3311,7 @@ impl Instance {
// SAFETY: Both `src` and `dst` have been validated
// above.
unsafe {
if write_caller == read_caller {
if write_caller_instance == read_caller_instance {
// If the same instance owns both ends of
// the stream, the source and destination
// buffers might overlap.
Expand Down Expand Up @@ -3326,6 +3353,10 @@ impl Instance {
let id = TableId::<TransmitHandle>::new(rep);
log::trace!("copy values {values:?} for {id:?}");

// Serializing the value may require calling the guest's realloc function, so we
// set the guest's thread context in case realloc requires it, and restore the original
// thread context after the copy is complete.
let old_thread = store.0.set_thread(read_caller_thread)?;
let lower = &mut LowerContext::new(store.as_context_mut(), read_options, self);
let ty = match lower.types[lower.types[read_ty].ty].payload {
Some(ty) => ty,
Expand All @@ -3348,6 +3379,7 @@ impl Instance {
value.store(lower, ty, ptr)?;
ptr += size
}
store.0.set_thread(old_thread)?;
}
}
_ => bail_bug!("mismatched transmit types in copy"),
Expand Down Expand Up @@ -3469,7 +3501,8 @@ impl Instance {
count: read_count,
handle: read_handle,
instance: read_instance,
caller: read_caller,
caller_instance: read_caller_instance,
caller_thread: read_caller_thread,
} => {
if flat_abi != read_flat_abi {
bail_bug!("expected flat ABI calculations to be the same");
Expand Down Expand Up @@ -3515,7 +3548,8 @@ impl Instance {
ty,
options,
address,
read_caller,
read_caller_instance,
read_caller_thread,
read_ty,
read_options,
read_address,
Expand Down Expand Up @@ -3557,7 +3591,8 @@ impl Instance {
count: read_count - count,
handle: read_handle,
instance: read_instance,
caller: read_caller,
caller_instance: read_caller_instance,
caller_thread: read_caller_thread,
};
}

Expand Down Expand Up @@ -3634,7 +3669,7 @@ impl Instance {
pub(super) fn guest_read<T: 'static>(
self,
mut store: StoreContextMut<T>,
caller: RuntimeComponentInstanceIndex,
caller_instance: RuntimeComponentInstanceIndex,
ty: TransmitIndex,
options: OptionsIndex,
flat_abi: Option<FlatAbi>,
Expand Down Expand Up @@ -3664,6 +3699,7 @@ impl Instance {
*state = TransmitLocalState::Busy;
let transmit_handle = TableId::<TransmitHandle>::new(rep);
let concurrent_state = store.0.concurrent_state_mut();
let caller_thread = concurrent_state.current_guest_thread()?;
let transmit_id = concurrent_state.get_mut(transmit_handle)?.state;
let transmit = concurrent_state.get_mut(transmit_id)?;
log::trace!(
Expand Down Expand Up @@ -3694,7 +3730,8 @@ impl Instance {
count,
handle,
instance: self,
caller,
caller_instance,
caller_thread,
};
Ok::<_, crate::Error>(())
};
Expand Down Expand Up @@ -3737,7 +3774,8 @@ impl Instance {
write_ty,
write_options,
write_address,
caller,
caller_instance,
caller_thread,
ty,
options,
address,
Expand Down
Loading
Loading