diff --git a/vm/devices/vmbus/vmbus_channel/src/channel.rs b/vm/devices/vmbus/vmbus_channel/src/channel.rs index 568ab2a390..641a1c3856 100644 --- a/vm/devices/vmbus/vmbus_channel/src/channel.rs +++ b/vm/devices/vmbus/vmbus_channel/src/channel.rs @@ -16,6 +16,7 @@ use crate::gpadl::GpadlMapView; use anyhow::Context; use async_trait::async_trait; use futures::StreamExt; +use futures::future::join_all; use futures::stream::SelectAll; use futures::stream::select; use inspect::Inspect; @@ -513,7 +514,22 @@ impl Device { } } } - // Revoke the channel. + + // Revoke all subchannels before the primary channel, so that the + // guest sees the rescind for every subchannel before the rescind for + // the primary. Issuing `Revoke` RPCs concurrently and awaiting them + // all via `join_all` lets `vmbus_server` rescind the subchannels in + // any order, but guarantees they have all been emitted before the + // primary's sender is dropped below. + let subchannel_senders = self.server_requests.split_off(1); + join_all( + subchannel_senders + .into_iter() + .map(|s| s.call(ChannelServerRequest::Revoke, ())), + ) + .await; + + // Revoke the primary channel by dropping its sender. drop(self.server_requests); // Wait for the revokes to finish. // When vmbus (sub)channels are closed, `self.requests` ends up with stale diff --git a/vm/devices/vmbus/vmbus_server/src/tests.rs b/vm/devices/vmbus/vmbus_server/src/tests.rs index 53c07d38d0..86d7e31a35 100644 --- a/vm/devices/vmbus/vmbus_server/src/tests.rs +++ b/vm/devices/vmbus/vmbus_server/src/tests.rs @@ -14,6 +14,7 @@ use protocol::UserDefinedData; use std::time::Duration; use test_with_tracing::test; use vmbus_channel::bus::OfferParams; +use vmbus_channel::channel::ChannelControl; use vmbus_channel::channel::ChannelOpenError; use vmbus_channel::channel::DeviceResources; use vmbus_channel::channel::SaveRestoreVmbusDevice; @@ -989,3 +990,158 @@ async fn test_server_monitor_page_helper( } } } + +/// A `VmbusDevice` that supports a configurable number of subchannels and +/// exposes its `ChannelControl` so the test driver can request subchannel +/// offers at the desired point in the test. +#[derive(InspectMut)] +struct SubchannelTestDevice { + #[inspect(skip)] + id: u32, + #[inspect(skip)] + max_subchannels: u16, + #[inspect(skip)] + channel_control: Arc>>, +} + +impl SubchannelTestDevice { + fn new(id: u32, max_subchannels: u16) -> (Self, Arc>>) { + let channel_control = Arc::new(Mutex::new(None)); + ( + Self { + id, + max_subchannels, + channel_control: channel_control.clone(), + }, + channel_control, + ) + } +} + +#[async_trait] +impl VmbusDevice for SubchannelTestDevice { + fn offer(&self) -> OfferParams { + let guid = Guid { + data1: self.id, + ..Guid::ZERO + }; + OfferParams { + interface_name: "subchannel-test".into(), + instance_id: guid, + interface_id: guid, + channel_type: vmbus_channel::bus::ChannelType::Device { + pipe_packets: false, + }, + ..Default::default() + } + } + + fn max_subchannels(&self) -> u16 { + self.max_subchannels + } + + fn install(&mut self, resources: DeviceResources) { + *self.channel_control.lock() = Some(resources.channel_control); + } + + async fn open( + &mut self, + _channel_idx: u16, + _open_request: &OpenRequest, + ) -> Result<(), ChannelOpenError> { + Ok(()) + } + + async fn close(&mut self, _channel_idx: u16) {} + + async fn retarget_vp(&mut self, _channel_idx: u16, _target_vp: u32) {} + + fn start(&mut self) {} + + async fn stop(&mut self) {} + + fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> { + None + } +} + +/// Verifies that when a device with subchannels is revoked, the guest receives +/// the `RescindChannelOffer` messages for all of the subchannels before +/// receiving the rescind for the primary channel. +#[async_test] +async fn test_revoke_subchannels_before_primary(spawner: DefaultDriver) { + const NUM_SUBCHANNELS: u16 = 3; + + let mut env = TestEnv::new(spawner.clone()); + + let (device, channel_control) = SubchannelTestDevice::new(1, NUM_SUBCHANNELS); + let control = env.vmbus.control(); + let handle = offer_channel(&spawner, control.as_ref(), device) + .await + .expect("offer failed"); + + env.vmbus.start(); + + // Initiate contact and request offers. Only the primary channel is + // offered at this point (subchannels are enabled below). + env.initiate_contact( + protocol::Version::Copper, + protocol::FeatureFlags::new(), + false, + false, + ); + env.expect_response(protocol::MessageType::VERSION_RESPONSE) + .await; + env.synic.send_message(protocol::RequestOffers {}); + + let primary_offer = env.get_response::().await; + assert_eq!(primary_offer.subchannel_index, 0); + let primary_id = primary_offer.channel_id; + env.expect_response(protocol::MessageType::ALL_OFFERS_DELIVERED) + .await; + + // Enable subchannels; the device task will offer them and we should see + // NUM_SUBCHANNELS additional `OfferChannel` messages. + channel_control + .lock() + .as_ref() + .expect("channel control installed") + .enable_subchannels(NUM_SUBCHANNELS) + .expect("enable_subchannels"); + + let mut subchannel_ids = Vec::new(); + for _ in 0..NUM_SUBCHANNELS { + let offer = env.get_response::().await; + assert_ne!(offer.subchannel_index, 0); + subchannel_ids.push(offer.channel_id); + } + + // Revoke the channel via the handle. This drives the device task in + // `vmbus_channel::channel::Device::run_channel` through the teardown path + // under test: subchannels must be revoked before the primary. + handle.revoke().await; + + // Drain NUM_SUBCHANNELS + 1 `RescindChannelOffer` messages. The primary's + // rescind must arrive only after every subchannel's rescind has been + // delivered. + let mut remaining_subs: std::collections::HashSet<_> = subchannel_ids.iter().copied().collect(); + let mut seen_primary = false; + for _ in 0..(NUM_SUBCHANNELS as usize + 1) { + let msg = env.get_response::().await; + if msg.channel_id == primary_id { + assert!( + remaining_subs.is_empty(), + "primary rescinded while subchannels still pending: {:?}", + remaining_subs, + ); + seen_primary = true; + } else { + assert!( + remaining_subs.remove(&msg.channel_id), + "unexpected rescind for channel id {:?}", + msg.channel_id, + ); + } + } + assert!(seen_primary, "primary rescind not observed"); +}