From 2626a2e95c0999e1f4bc52f31ac1807a6fe9ac0c Mon Sep 17 00:00:00 2001 From: erfrimod Date: Fri, 22 May 2026 17:46:48 -0700 Subject: [PATCH 1/2] fix subchannel panic on endpoint restart --- vm/devices/net/netvsp/src/lib.rs | 153 ++++++++++++----- vm/devices/net/netvsp/src/test.rs | 269 +++++++++++++++++++++++++++++- 2 files changed, 380 insertions(+), 42 deletions(-) diff --git a/vm/devices/net/netvsp/src/lib.rs b/vm/devices/net/netvsp/src/lib.rs index d6fc9e62a2..8686ca5f55 100644 --- a/vm/devices/net/netvsp/src/lib.rs +++ b/vm/devices/net/netvsp/src/lib.rs @@ -158,7 +158,8 @@ enum CoordinatorMessage { Update(CoordinatorMessageUpdateType), /// Restart endpoints and resume processing. This will also attempt to set VF and data path state to match current /// expectations. - Restart, + /// Identifies the channel that requested the restart. 0 = primary; >0 = sub-channel + Restart { channel_idx: u16 }, /// Start a timer. StartTimer(Instant), } @@ -1547,7 +1548,9 @@ impl Nic { driver_builder.run_on_target(!self.adapter.tx_fast_completions); #[expect(clippy::disallowed_methods)] // TODO - let (send, recv) = mpsc::channel(1); + // Capacity equals the number of workers. + // Each channel (primary or subchannel) can send at most one message. + let (send, recv) = mpsc::channel(self.adapter.max_queues as usize); self.coordinator_send = Some(send); self.coordinator.insert( &self.adapter.driver, @@ -3065,7 +3068,7 @@ impl NetChannel { // Restart the endpoint if the OID changed some critical // endpoint property. if restart_endpoint { - self.restart = Some(CoordinatorMessage::Restart); + self.restart = Some(CoordinatorMessage::Restart { channel_idx: 0 }); } if let Some(filter) = packet_filter { if self.packet_filter != filter { @@ -3267,7 +3270,7 @@ impl NetChannel { guest_vf_state: guest_vf, filter_state: packet_filter, })); - } else if let Some(CoordinatorMessage::Restart) = self.restart { + } else if let Some(CoordinatorMessage::Restart { .. }) = self.restart { // If a restart message is pending, do nothing. // A restart will try to switch the data path based on primary.guest_vf_state. // A restart will apply packet filter changes. @@ -4051,30 +4054,14 @@ impl Coordinator { state: &mut CoordinatorState, ) -> Result<(), task_control::Cancelled> { loop { + // Drain any messages already queued on `recv` to decide + // whether to run the restart cycle. + self.drain_pending_messages(state).await; + + // If anything (primary, sub, or endpoint) requested a restart, do + // it. Then loop back to re-drain in case more messages arrived. if self.restart { - stop.until_stopped(self.stop_workers()).await?; - // The queue restart operation is not restartable, so do not - // poll on `stop` here. - if let Err(err) = self - .restart_queues(state) - .instrument(tracing::info_span!("netvsp_restart_queues")) - .await - { - tracing::error!( - error = &err as &dyn std::error::Error, - "failed to restart queues" - ); - } - if let Some(primary) = self.primary_mut() { - primary.is_data_path_switched = - state.endpoint.get_data_path_to_guest_vf().await.ok(); - tracing::info!( - is_data_path_switched = primary.is_data_path_switched, - "Query data path state" - ); - } - self.restore_guest_vf_state(state).await; - self.restart = false; + self.run_restart_cycle(stop, state).await?; } // Ensure that all workers except the primary are started. The @@ -4228,10 +4215,81 @@ impl Coordinator { Ok(()) } + /// Called at the top of each iteration of [`Self::process`] loop. + /// Coalesce Sub-channel restart requests into `self.restart = true`. + /// Any Primary message landing concurrently with a `Restart` is + /// handled prior to running the restart cycle. + async fn drain_pending_messages(&mut self, state: &mut CoordinatorState) { + while let Ok(Some(msg)) = self.recv.try_next() { + self.handle_coordinator_message(msg, state).await; + } + } + + /// Called from the [`Self::process`] loop when a `Restart` message has + /// been observed from any channel. + async fn run_restart_cycle( + &mut self, + stop: &mut StopTask<'_>, + state: &mut CoordinatorState, + ) -> Result<(), task_control::Cancelled> { + stop.until_stopped(self.stop_workers()).await?; + + // All workers are stopped and cannot push new messages. + // Drain any messages that arrived prior to the stop. + while let Ok(Some(msg)) = self.recv.try_next() { + match msg { + CoordinatorMessage::Restart { .. } => { + // Discarding any additional Restart messages. + } + // Ensure any non-restart message from the Primary is + // handled prior to restarting the workers. + other => { + self.handle_primary_message(other, state).await; + } + } + } + + // The queue restart operation is not restartable; do not poll on stop here. + if let Err(err) = self + .restart_queues(state) + .instrument(tracing::info_span!("netvsp_restart_queues")) + .await + { + tracing::error!( + error = &err as &dyn std::error::Error, + "failed to restart queues" + ); + } + if let Some(primary) = self.primary_mut() { + primary.is_data_path_switched = state.endpoint.get_data_path_to_guest_vf().await.ok(); + tracing::info!( + is_data_path_switched = primary.is_data_path_switched, + "Query data path state" + ); + } + self.restore_guest_vf_state(state).await; + self.restart = false; + Ok(()) + } + async fn handle_coordinator_message( &mut self, msg: CoordinatorMessage, state: &mut CoordinatorState, + ) { + match msg { + CoordinatorMessage::Restart { channel_idx } if channel_idx != 0 => { + tracing::info!(channel_idx, "sub-channel triggered restart"); + self.restart = true; + } + _ => self.handle_primary_message(msg, state).await, + } + } + + async fn handle_primary_message( + &mut self, + msg: CoordinatorMessage, + state: &mut CoordinatorState, ) { self.workers[0].stop().await; if let Some(worker) = self.workers[0].state_mut() { @@ -4269,7 +4327,11 @@ impl Coordinator { CoordinatorMessage::StartTimer(deadline) => { self.sleep_deadline = Some(deadline); } - CoordinatorMessage::Restart => self.restart = true, + CoordinatorMessage::Restart { channel_idx } => { + assert_eq!(channel_idx, 0); + tracing::info!(channel_idx, "primary-channel triggered restart"); + self.restart = true; + } } } @@ -4806,7 +4868,9 @@ impl Worker { }; // Wake up the coordinator task to start the queues. - let _ = self.coordinator_send.try_send(CoordinatorMessage::Restart); + let _ = self + .coordinator_send + .try_send(CoordinatorMessage::Restart { channel_idx: 0 }); tracelimit::info_ratelimited!("network initialized"); self.state = WorkerState::WaitingForCoordinator(Some(state)); @@ -4846,23 +4910,30 @@ impl Worker { Err(WorkerError::EndpointRequiresQueueRestart(err)) => { tracelimit::warn_ratelimited!( err = err.as_ref() as &dyn std::error::Error, + channel_idx = self.channel_idx, "Endpoint requires queues to restart", ); - CoordinatorMessage::Restart + CoordinatorMessage::Restart { + channel_idx: self.channel_idx, + } } Err(err) => return Err(err), }; - let WorkerState::Ready(ready) = std::mem::replace( - &mut self.state, - WorkerState::WaitingForCoordinator(None), - ) else { - unreachable!("must be running in ready state") - }; - let _ = std::mem::replace( - &mut self.state, - WorkerState::WaitingForCoordinator(Some(ready)), - ); + // Only the Primary channel transitions to `WaitingForCoordinator`. + // Sub-channels stay in `Ready(_)`. + if self.channel_idx == 0 { + let WorkerState::Ready(ready) = std::mem::replace( + &mut self.state, + WorkerState::WaitingForCoordinator(None), + ) else { + unreachable!("must be running in ready state") + }; + let _ = std::mem::replace( + &mut self.state, + WorkerState::WaitingForCoordinator(Some(ready)), + ); + } self.coordinator_send .try_send(msg) .map_err(WorkerError::CoordinatorMessageSendFailed)?; @@ -5637,7 +5708,7 @@ impl NetChannel { let primary = state.primary.as_mut().unwrap(); primary.requested_num_queues = subchannel_count as u16 + 1; primary.tx_spread_sent = false; - self.restart = Some(CoordinatorMessage::Restart); + self.restart = Some(CoordinatorMessage::Restart { channel_idx: 0 }); } } PacketData::RevokeReceiveBuffer(protocol::Message1RevokeReceiveBuffer { id }) diff --git a/vm/devices/net/netvsp/src/test.rs b/vm/devices/net/netvsp/src/test.rs index ad0f1a2902..c0b1a041ae 100644 --- a/vm/devices/net/netvsp/src/test.rs +++ b/vm/devices/net/netvsp/src/test.rs @@ -148,6 +148,9 @@ struct TestNicEndpointState { /// `(false, N)`, leaving packets in-flight. pub sync_tx: bool, pub tx_metadata: Vec, + /// Per-queue one-shot trigger: when `tx_restart_triggers[idx]` is true, + /// the next `tx_poll` call on queue `idx` returns `TxError::TryRestart(_)`. + pub tx_restart_triggers: Vec, } impl TestNicEndpointState { @@ -162,6 +165,7 @@ impl TestNicEndpointState { queues: Vec::new(), sync_tx: true, tx_metadata: Vec::new(), + tx_restart_triggers: Vec::new(), })) } @@ -185,6 +189,15 @@ impl TestNicEndpointState { pub fn send_rx_with_metadata(&self, queue_idx: usize, data: Vec, metadata: RxMetadata) { self.queues[queue_idx].send((data, metadata)); } + + /// Arm a one-shot trigger so the next `tx_poll` on `queue_idx` returns TryRestart. + pub fn trigger_tx_restart(&mut self, queue_idx: usize) { + assert!( + queue_idx < self.tx_restart_triggers.len(), + "trigger_tx_restart called before get_queues populated triggers" + ); + self.tx_restart_triggers[queue_idx] = true; + } } struct TestNicEndpointInner { @@ -260,13 +273,15 @@ impl net_backend::Endpoint for TestNicEndpoint { .is_none_or(|s| s.lock().sync_tx); let senders = config .into_iter() - .map(|config| { + .enumerate() + .map(|(queue_idx, config)| { let (tx, rx) = mesh::channel(); queues.push(Box::new(TestNicQueue::new( config, rx, sync_tx, inner.endpoint_state.clone(), + queue_idx, ))); tx }) @@ -274,6 +289,7 @@ impl net_backend::Endpoint for TestNicEndpoint { if let Some(endpoint_state) = &inner.endpoint_state { let mut locked_data = endpoint_state.lock(); + locked_data.tx_restart_triggers = vec![false; senders.len()]; locked_data.queues = senders; } Ok(()) @@ -365,6 +381,7 @@ struct TestNicQueue { #[inspect(skip)] next_rx_packet: Option<(Vec, RxMetadata)>, sync_tx: bool, + queue_idx: usize, } impl TestNicQueue { @@ -373,6 +390,7 @@ impl TestNicQueue { rx: mesh::Receiver<(Vec, RxMetadata)>, sync_tx: bool, endpoint_state: Option>>, + queue_idx: usize, ) -> Self { Self { rx_ids: VecDeque::new(), @@ -380,6 +398,7 @@ impl TestNicQueue { endpoint_state, next_rx_packet: None, sync_tx, + queue_idx, } } } @@ -471,6 +490,21 @@ impl NetQueue for TestNicQueue { _pool: &mut dyn BufferAccess, _done: &mut [TxId], ) -> Result { + if let Some(endpoint_state) = &self.endpoint_state { + let mut locked = endpoint_state.lock(); + if locked + .tx_restart_triggers + .get(self.queue_idx) + .copied() + .unwrap_or(false) + { + locked.tx_restart_triggers[self.queue_idx] = false; + return Err(TxError::TryRestart(anyhow::anyhow!( + "test-injected tx_poll restart on queue {}", + self.queue_idx + ))); + } + } Ok(0) } } @@ -7410,3 +7444,236 @@ async fn vlan_rx_counter_increments(driver: DefaultDriver) { "netvsp should count 1 VLAN RX packet" ); } + +#[async_test] +async fn subchannel_tx_restart(driver: DefaultDriver) { + const TOTAL_QUEUES: u32 = 4; + let endpoint_state = TestNicEndpointState::new(); + let endpoint = TestNicEndpoint::new(Some(endpoint_state.clone())); + let nic = Nic::builder().build( + &VmTaskDriverSource::new(SingleDriverBackend::new(driver.clone())), + Guid::new_random(), + Box::new(endpoint), + [1, 2, 3, 4, 5, 6].into(), + 0, + ); + + let mut nic = TestNicDevice::new_with_nic(&driver, nic).await; + nic.start_vmbus_channel(); + let mut channel = nic.connect_vmbus_channel().await; + channel + .initialize( + TOTAL_QUEUES as usize - 1, + protocol::NdisConfigCapabilities::new(), + ) + .await; + + // RNDIS initialize. + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_INITIALIZE_MSG, + rndisprot::InitializeRequest { + request_id: 1, + major_version: rndisprot::MAJOR_VERSION, + minor_version: rndisprot::MINOR_VERSION, + max_transfer_size: 0, + }, + &[], + ) + .await; + let _: rndisprot::InitializeComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_INITIALIZE_CMPLT) + .await + .unwrap(); + + // Allocate the sub-channels. + let alloc_message = NvspMessage { + header: protocol::MessageHeader { + message_type: protocol::MESSAGE5_TYPE_SUB_CHANNEL, + }, + data: protocol::Message5SubchannelRequest { + operation: protocol::SubchannelOperation::ALLOCATE, + num_sub_channels: TOTAL_QUEUES - 1, + }, + padding: &[], + }; + channel + .write(OutgoingPacket { + transaction_id: 123, + packet_type: OutgoingPacketType::InBandWithCompletion, + payload: &alloc_message.payload(), + }) + .await; + channel + .read_with(|packet| match packet { + IncomingPacket::Completion(completion) => { + let mut reader = completion.reader(); + let header: protocol::MessageHeader = reader.read_plain().unwrap(); + assert_eq!(header.message_type, protocol::MESSAGE5_TYPE_SUB_CHANNEL); + let completion_data: protocol::Message5SubchannelComplete = + reader.read_plain().unwrap(); + assert_eq!(completion_data.status, protocol::Status::SUCCESS); + assert_eq!(completion_data.num_sub_channels, TOTAL_QUEUES - 1); + } + _ => panic!("Unexpected packet"), + }) + .await + .expect("sub-channel allocation completion"); + + for idx in 1..TOTAL_QUEUES { + channel.connect_subchannel(idx).await; + } + + // Drain the indirection-table data packet from the primary and complete it. + let transaction_id = channel + .read_with(|packet| match packet { + IncomingPacket::Data(packet) => { + let mut reader = packet.reader(); + let header: protocol::MessageHeader = reader.read_plain().unwrap(); + assert_eq!( + header.message_type, + protocol::MESSAGE5_TYPE_SEND_INDIRECTION_TABLE + ); + packet.transaction_id() + } + _ => panic!("Unexpected packet"), + }) + .await + .expect("indirection table message"); + if let Some(transaction_id) = transaction_id { + channel + .write(OutgoingPacket { + transaction_id, + packet_type: OutgoingPacketType::Completion, + payload: &NvspMessage { + header: protocol::MessageHeader { + message_type: protocol::MESSAGE1_TYPE_SEND_RNDIS_PACKET_COMPLETE, + }, + data: protocol::Message1SendRndisPacketComplete { + status: protocol::Status::SUCCESS, + }, + padding: &[], + } + .payload(), + }) + .await; + } + + // Enable a packet filter so RX traffic is delivered to the guest. + let request_id = 456; + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_SET_MSG, + rndisprot::SetRequest { + request_id, + oid: rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER, + information_buffer_length: size_of::() as u32, + information_buffer_offset: size_of::() as u32, + device_vc_handle: 0, + }, + &rndisprot::NPROTO_PACKET_FILTER.to_le_bytes(), + ) + .await; + let set_complete: rndisprot::SetComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_SET_CMPLT) + .await + .unwrap(); + assert_eq!(set_complete.request_id, request_id); + assert_eq!(set_complete.status, rndisprot::STATUS_SUCCESS); + + // Send a SECOND OID_GEN_CURRENT_PACKET_FILTER SET with a different + // filter value. The primary worker handles this OID inline, then pushes + // a `CoordinatorMessage::Update` to the coordinator. + const NEW_FILTER: u32 = rndisprot::NDIS_PACKET_TYPE_DIRECTED; + let request_id_filter_change = 457; + channel + .send_rndis_control_message( + rndisprot::MESSAGE_TYPE_SET_MSG, + rndisprot::SetRequest { + request_id: request_id_filter_change, + oid: rndisprot::Oid::OID_GEN_CURRENT_PACKET_FILTER, + information_buffer_length: size_of::() as u32, + information_buffer_offset: size_of::() as u32, + device_vc_handle: 0, + }, + &NEW_FILTER.to_le_bytes(), + ) + .await; + let set_complete: rndisprot::SetComplete = channel + .read_rndis_control_message(rndisprot::MESSAGE_TYPE_SET_CMPLT) + .await + .unwrap(); + assert_eq!(set_complete.request_id, request_id_filter_change); + assert_eq!(set_complete.status, rndisprot::STATUS_SUCCESS); + + let stop_before = endpoint_state.lock().stop_endpoint_counter; + + // Arm a one-shot TryRestart on every sub-channel queue and wake each + // one by routing an RX packet through it. Each sub-channel's main_loop + // will call tx_poll, hit TryRestart, and route a Restart to the coordinator. + { + let mut locked = endpoint_state.lock(); + for idx in 1..TOTAL_QUEUES as usize { + locked.trigger_tx_restart(idx); + locked.send_rx(idx, vec![0xAA + idx as u8; 60]); + } + } + + // Wait for the coordinator to observe at least one Restart message and + // run restart_queues (which calls endpoint.stop()). + // All three `Restart` messages coalesce into a single restart cycle. + let deadline = std::time::Instant::now() + Duration::from_secs(5); + loop { + if endpoint_state.lock().stop_endpoint_counter > stop_before { + break; + } + if std::time::Instant::now() >= deadline { + panic!( + "restart_queues did not run within timeout (stop_before={}, current={})", + stop_before, + endpoint_state.lock().stop_endpoint_counter + ); + } + // Yield to let the workers and coordinator make progress. + let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_millis(50)); + let _ = ctx.until_cancelled(pending::<()>()).await; + } + + // Bound the number of restart cycles. + let stop_after = endpoint_state.lock().stop_endpoint_counter; + let cycles = stop_after - stop_before; + assert_eq!( + cycles, 1, + "expected exactly 1 restart cycle (pass B coalesces all in-flight sub-channel TryRestarts), got {cycles}", + ); + + // The `Update` filter change being visible on every sub-channel proves + // the Primary channel message was not lost. + for idx in 1..TOTAL_QUEUES as usize { + let current = + read_netvsp_counter(&channel.nic.channel, &format!("queues/{idx}/packet_filter")).await + as u32; + assert_eq!( + current, NEW_FILTER, + "sub-channel {idx} packet_filter not propagated before restart cycle finished", + ); + } + + // Deliver another RX packet on every sub-channel and verify each worker + // is alive and functional. + { + let locked = endpoint_state.lock(); + for idx in 1..TOTAL_QUEUES as usize { + locked.send_rx(idx, vec![0xBB + idx as u8; 60]); + } + } + for idx in 1..TOTAL_QUEUES { + channel + .read_subchannel_with(idx, |packet| match packet { + IncomingPacket::Data(_) => (), + _ => panic!("Unexpected packet on sub-channel {idx}"), + }) + .await + .unwrap_or_else(|_| panic!("sub-channel {idx} RX packet after restart cycle")); + } +} From cb6fec231fa6f9e1b451af6d4bc99cf9c67dfab4 Mon Sep 17 00:00:00 2001 From: erfrimod Date: Fri, 22 May 2026 18:19:15 -0700 Subject: [PATCH 2/2] drain in run_restart_cycle uses same handle_coordinator_message to ensure a primary channel that sent restart is moved back to Ready --- vm/devices/net/netvsp/src/lib.rs | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/vm/devices/net/netvsp/src/lib.rs b/vm/devices/net/netvsp/src/lib.rs index 8686ca5f55..8dc5b3a963 100644 --- a/vm/devices/net/netvsp/src/lib.rs +++ b/vm/devices/net/netvsp/src/lib.rs @@ -4236,18 +4236,7 @@ impl Coordinator { // All workers are stopped and cannot push new messages. // Drain any messages that arrived prior to the stop. - while let Ok(Some(msg)) = self.recv.try_next() { - match msg { - CoordinatorMessage::Restart { .. } => { - // Discarding any additional Restart messages. - } - // Ensure any non-restart message from the Primary is - // handled prior to restarting the workers. - other => { - self.handle_primary_message(other, state).await; - } - } - } + self.drain_pending_messages(state).await; // The queue restart operation is not restartable; do not poll on stop here. if let Err(err) = self