diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2834728e3..b840aa577 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -74,6 +74,8 @@ find_nuget_package(Microsoft.WSL.DeviceHost WSL_DEVICE_HOST /build/native)
find_nuget_package(Microsoft.WSL.Kernel KERNEL /build/native)
find_nuget_package(Microsoft.WSL.bsdtar BSDTARD /build/native/bin)
find_nuget_package(Microsoft.WSL.LinuxSdk LINUXSDK /)
+# NOTE(review): find_nuget_package(Microsoft.WSL.BPF WSL_BPF /) is disabled; the package directory below is hard-coded to version 1.0.0 instead - confirm before merging.
+set(WSL_BPF_SOURCE_DIR "${CMAKE_SOURCE_DIR}/packages/Microsoft.WSL.BPF.1.0.0")
find_nuget_package(Microsoft.WSL.TestDistro TEST_DISTRO /)
find_nuget_package(Microsoft.WSLg WSLG /build/native/bin)
find_nuget_package(vswhere VSWHERE /tools)
diff --git a/packages.config b/packages.config
index 22cc87bb7..a71b31702 100644
--- a/packages.config
+++ b/packages.config
@@ -16,6 +16,7 @@
+
diff --git a/src/linux/bpf/CMakeLists.txt b/src/linux/bpf/CMakeLists.txt
new file mode 100644
index 000000000..e069def24
--- /dev/null
+++ b/src/linux/bpf/CMakeLists.txt
@@ -0,0 +1,124 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+#
+# CMakeLists.txt for building BPF skeleton headers.
+# This is intended to be run on a Linux build machine, NOT as part of the main Windows build.
+#
+# Required parameters:
+# -DKERNEL_IMAGE_PATH=<path> Path to the WSL kernel binary (bzImage/Image) with embedded BTF
+# -DBPFTOOL_PATH=<path> Path to the bpftool binary
+#
+# Optional parameters:
+# -DCLANG_BPF=<path> Path to clang (default: clang)
+# -DOUTPUT_DIR=<path> Output directory for generated headers (default: ${CMAKE_BINARY_DIR}/output)
+#
+# Usage:
+# cmake -S . -B build -DKERNEL_IMAGE_PATH=/path/to/bzImage -DBPFTOOL_PATH=/usr/sbin/bpftool
+# cmake --build build
+#
+# Output:
+# ${OUTPUT_DIR}/bind_monitor.skel.h - BPF skeleton header (embed in NuGet)
+# ${OUTPUT_DIR}/bind_monitor.h - Shared event struct header (embed in NuGet)
+
+cmake_minimum_required(VERSION 3.20)
+project(wsl-bpf LANGUAGES NONE)
+
+# Validate required parameters.
+# Each one must be passed on the command line and must name an existing file;
+# a directory is rejected because the later build steps read the path as a file.
+if(NOT DEFINED KERNEL_IMAGE_PATH)
+    message(FATAL_ERROR "KERNEL_IMAGE_PATH is required. Pass -DKERNEL_IMAGE_PATH=<path>")
+endif()
+
+if(NOT EXISTS "${KERNEL_IMAGE_PATH}" OR IS_DIRECTORY "${KERNEL_IMAGE_PATH}")
+    message(FATAL_ERROR "Kernel image not found: ${KERNEL_IMAGE_PATH}")
+endif()
+
+if(NOT DEFINED BPFTOOL_PATH)
+    message(FATAL_ERROR "BPFTOOL_PATH is required. Pass -DBPFTOOL_PATH=<path>")
+endif()
+
+if(NOT EXISTS "${BPFTOOL_PATH}" OR IS_DIRECTORY "${BPFTOOL_PATH}")
+    message(FATAL_ERROR "bpftool not found: ${BPFTOOL_PATH}")
+endif()
+
+# Optional parameters: fall back to defaults when not supplied.
+if(NOT DEFINED CLANG_BPF)
+    set(CLANG_BPF "clang")
+endif()
+
+if(NOT DEFINED OUTPUT_DIR)
+    set(OUTPUT_DIR "${CMAKE_BINARY_DIR}/output")
+endif()
+
+# Create the output directory at configure time so the custom commands can write into it.
+file(MAKE_DIRECTORY "${OUTPUT_DIR}")
+
+# Source directory plus the intermediate and output artifacts of the pipeline.
+set(SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(VMLINUX_H "${CMAKE_BINARY_DIR}/vmlinux.h")
+set(BPF_OBJ "${CMAKE_BINARY_DIR}/bind_monitor.bpf.o")
+set(SKEL_H "${OUTPUT_DIR}/bind_monitor.skel.h")
+set(EVENT_H "${OUTPUT_DIR}/bind_monitor.h")
+
+# Select the value used for the -D__TARGET_ARCH_* define passed to clang.
+# It can be supplied explicitly with -DBPF_TARGET_ARCH=x86 or -DBPF_TARGET_ARCH=arm64
+# (for example when the kernel image targets a different architecture than the build
+# machine); otherwise it is derived from the build machine's `uname -m`.
+if(NOT DEFINED BPF_TARGET_ARCH)
+    execute_process(
+        COMMAND uname -m
+        OUTPUT_VARIABLE HOST_ARCH
+        RESULT_VARIABLE UNAME_RESULT
+        OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+    # RESULT_VARIABLE is the exit code, or an error string if the process could not be started.
+    if(NOT UNAME_RESULT EQUAL 0)
+        message(FATAL_ERROR "Failed to run 'uname -m' (${UNAME_RESULT}). Pass -DBPF_TARGET_ARCH=x86 or -DBPF_TARGET_ARCH=arm64")
+    endif()
+
+    if(HOST_ARCH STREQUAL "x86_64")
+        set(BPF_TARGET_ARCH "x86")
+    elseif(HOST_ARCH STREQUAL "aarch64")
+        set(BPF_TARGET_ARCH "arm64")
+    else()
+        message(FATAL_ERROR "Unsupported architecture: ${HOST_ARCH}")
+    endif()
+elseif(NOT BPF_TARGET_ARCH MATCHES "^(x86|arm64)$")
+    message(FATAL_ERROR "Unsupported architecture: ${BPF_TARGET_ARCH}")
+endif()
+
+# Locate libbpf headers (needed for BPF compilation)
+find_path(LIBBPF_INCLUDE_DIR "bpf/bpf_helpers.h"
+    PATHS /usr/include /usr/local/include
+    REQUIRED)
+
+# Step 1: Generate vmlinux.h from kernel image.
+# generate-vmlinux-header.sh takes its bpftool from the BPFTOOL environment variable and
+# falls back to `bpftool` on PATH, so forward the required BPFTOOL_PATH to it explicitly.
+add_custom_command(
+    OUTPUT "${VMLINUX_H}"
+    COMMAND ${CMAKE_COMMAND} -E env "BPFTOOL=${BPFTOOL_PATH}"
+            bash "${SRC_DIR}/generate-vmlinux-header.sh" "${KERNEL_IMAGE_PATH}" "${VMLINUX_H}"
+    DEPENDS "${KERNEL_IMAGE_PATH}" "${SRC_DIR}/generate-vmlinux-header.sh"
+    COMMENT "Generating vmlinux.h from kernel image"
+    VERBATIM)
+
+# Step 2: Compile bind_monitor.bpf.c into a BPF object with clang; the generated vmlinux.h is found through the -I "${CMAKE_BINARY_DIR}" include path
+add_custom_command(
+ OUTPUT "${BPF_OBJ}"
+ COMMAND "${CLANG_BPF}" -O2 -g -target bpf
+ -D__TARGET_ARCH_${BPF_TARGET_ARCH}
+ -I "${CMAKE_BINARY_DIR}"
+ -I "${LIBBPF_INCLUDE_DIR}"
+ -I "${SRC_DIR}"
+ -c "${SRC_DIR}/bind_monitor.bpf.c"
+ -o "${BPF_OBJ}"
+ DEPENDS "${SRC_DIR}/bind_monitor.bpf.c" "${SRC_DIR}/bind_monitor.h" "${VMLINUX_H}" # depending on vmlinux.h makes Step 1 run first
+ COMMENT "Compiling bind_monitor.bpf.c"
+ VERBATIM)
+
+# Step 3: Generate skeleton header.
+# bpftool prints the skeleton on stdout; the output is redirected into the header file.
+add_custom_command(
+    OUTPUT "${SKEL_H}"
+    DEPENDS "${BPF_OBJ}"
+    COMMAND "${BPFTOOL_PATH}" gen skeleton "${BPF_OBJ}" > "${SKEL_H}"
+    VERBATIM
+    COMMENT "Generating bind_monitor.skel.h")
+
+# Step 4: Place a copy of the shared event header next to the skeleton.
+add_custom_command(
+    OUTPUT "${EVENT_H}"
+    DEPENDS "${SRC_DIR}/bind_monitor.h"
+    COMMAND ${CMAKE_COMMAND} -E copy "${SRC_DIR}/bind_monitor.h" "${EVENT_H}"
+    VERBATIM
+    COMMENT "Copying bind_monitor.h to output")
+
+# Default target: building it produces both headers.
+add_custom_target(bpf_headers ALL DEPENDS "${SKEL_H}" "${EVENT_H}")
+
+# Configuration summary printed at configure time.
+message(STATUS "BPF build configuration:")
+message(STATUS " Kernel image: ${KERNEL_IMAGE_PATH}")
+message(STATUS " bpftool: ${BPFTOOL_PATH}")
+message(STATUS " clang: ${CLANG_BPF}")
+message(STATUS " Target arch: ${BPF_TARGET_ARCH}")
+message(STATUS " Output dir: ${OUTPUT_DIR}")
diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c
new file mode 100644
index 000000000..4a2eaa586
--- /dev/null
+++ b/src/linux/bpf/bind_monitor.bpf.c
@@ -0,0 +1,78 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include "vmlinux.h"
+#include
+#include
+#include
+#include
+#include "bind_monitor.h"
+
+struct { /* Ring buffer map through which bind_event records are published */
+ __uint(type, BPF_MAP_TYPE_RINGBUF);
+ __uint(max_entries, BIND_MONITOR_RINGBUF_SIZE); /* size constant defined in bind_monitor.h */
+} events SEC(".maps");
+
+/*
+ * Publish one bind_event for a socket whose bind call has just returned.
+ *
+ * sk:  the socket that was bound (may be NULL, in which case nothing is emitted).
+ * ret: return value of the traced bind function; only successful (0) calls are reported.
+ *
+ * Only AF_INET / AF_INET6 sockets whose protocol is TCP or UDP are reported.
+ * If no ring buffer space can be reserved the event is dropped.
+ * Always returns 0.
+ */
+static __always_inline int emit_bind_event(struct sock *sk, int ret)
+{
+    struct bind_event *e;
+
+    if (!sk || ret != 0)
+        return 0;
+
+    __u16 family = BPF_CORE_READ(sk, __sk_common.skc_family);
+    if (family != AF_INET && family != AF_INET6)
+        return 0;
+
+    __u16 protocol = BPF_CORE_READ(sk, sk_protocol);
+    if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP)
+        return 0;
+
+    e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
+    if (!e)
+        return 0;
+
+    /* Start from an all-zero record so that no stale ring buffer bytes are
+     * published if one of the reads below fails or a field is not applicable. */
+    __builtin_memset(e, 0, sizeof(*e));
+
+    e->family = family;
+    e->protocol = protocol;
+    e->port = BPF_CORE_READ(sk, __sk_common.skc_num);
+
+    if (family == AF_INET)
+    {
+        e->addr4 = BPF_CORE_READ((struct inet_sock *)sk, inet_saddr);
+    }
+    else
+    {
+        struct ipv6_pinfo *pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6);
+
+        /* BPF_CORE_READ_INTO copies sizeof(*dst) bytes. Passing the array itself
+         * decays to a __u8 pointer and copies a single byte, so pass the address
+         * of the array to copy all 16 bytes of the IPv6 address. */
+        if (pinet6)
+            BPF_CORE_READ_INTO(&e->addr6, pinet6, saddr.in6_u.u6_addr8);
+    }
+
+    bpf_ringbuf_submit(e, 0);
+    return 0;
+}
+
+// Attached at the exit of inet_bind(): forwards the socket and the call's
+// result to emit_bind_event, which decides whether an event is published.
+SEC("fexit/inet_bind")
+int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr,
+             int addr_len, int ret)
+{
+    return emit_bind_event(BPF_CORE_READ(sock, sk), ret);
+}
+
+// Attached at the exit of inet6_bind(): same handling as the inet_bind() hook.
+SEC("fexit/inet6_bind")
+int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr,
+             int addr_len, int ret)
+{
+    return emit_bind_event(BPF_CORE_READ(sock, sk), ret);
+}
+
+char LICENSE[] SEC("license") = "GPL";
diff --git a/src/linux/bpf/bind_monitor.h b/src/linux/bpf/bind_monitor.h
new file mode 100644
index 000000000..6138e4a48
--- /dev/null
+++ b/src/linux/bpf/bind_monitor.h
@@ -0,0 +1,35 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#ifndef __u8 /* NOTE(review): __u8 is declared by typedef below, not #define, so this test looks always-true unless another header defines a __u8 macro - confirm the guard is intended */
+typedef unsigned char __u8;
+typedef unsigned short __u16;
+typedef unsigned int __u32;
+typedef unsigned long long __u64;
+#endif
+
+#ifndef AF_INET /* fallback values used only when the macros are not already defined */
+#define AF_INET 2
+#endif
+#ifndef AF_INET6
+#define AF_INET6 10
+#endif
+
+#ifndef IPPROTO_TCP /* same fallback scheme for the protocol numbers */
+#define IPPROTO_TCP 6
+#endif
+#ifndef IPPROTO_UDP
+#define IPPROTO_UDP 17
+#endif
+
+#define BIND_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB; used as max_entries of the "events" ring buffer map */
+
+struct bind_event { /* one record per reported bind; produced by bind_monitor.bpf.c */
+ __u32 family; /* AF_INET or AF_INET6 */
+ __u32 protocol; /* IPPROTO_TCP or IPPROTO_UDP */
+ __u16 port; /* host byte order */
+ __u16 pad; /* explicit padding, written as 0 by the producer */
+ __u32 addr4; /* IPv4 address (network byte order); 0 for AF_INET6 events */
+ __u8 addr6[16]; /* IPv6 address; all zero for AF_INET events */
+};
diff --git a/src/linux/bpf/generate-vmlinux-header.sh b/src/linux/bpf/generate-vmlinux-header.sh
new file mode 100644
index 000000000..5c7a73a74
--- /dev/null
+++ b/src/linux/bpf/generate-vmlinux-header.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+set -e
+
+# Both arguments are required: the kernel image to read and the header file to write.
+# The usage text goes to stderr so it is not mistaken for regular output.
+if [ $# -lt 2 ]; then
+    echo "Usage: $0 <kernel-image> <output-header>" >&2
+    exit 1
+fi
+
+BZIMAGE="$1" # kernel image to extract BTF from
+OUTPUT="$2" # path of the generated C header
+BPFTOOL="${BPFTOOL:-bpftool}" # bpftool binary; taken from the BPFTOOL environment variable when set
+
+TMPDIR=$(mktemp -d) # scratch directory for extracted intermediates
+trap 'rm -rf "$TMPDIR"' EXIT # remove the scratch directory whenever the script exits
+
+VMLINUX="$TMPDIR/vmlinux" # where an extracted ELF candidate is written
+
+echo "Extracting vmlinux from $BZIMAGE..."
+
+# Dump the BTF of the file given as $1 into $OUTPUT as C source.
+# Succeeds (and prints the completion message) only when bpftool succeeds and
+# the generated header has more than 100 lines.
+dump_btf_header() {
+    "$BPFTOOL" btf dump file "$1" format c > "$OUTPUT" 2>/dev/null || return 1
+    LINES=$(wc -l < "$OUTPUT")
+    [ "$LINES" -gt 100 ] || return 1
+    echo "Done. Generated $OUTPUT ($LINES lines)"
+}
+
+# Attempt 1: hand the input to bpftool as-is (works if it is already an ELF with BTF).
+if dump_btf_header "$BZIMAGE"; then
+    exit 0
+fi
+
+# Attempt 2: look for an ELF embedded in the image (e.g., ARM64 Image) and dump that.
+ELF_OFFSET=$(binwalk -y elf "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true
+if [ -n "$ELF_OFFSET" ]; then
+    echo "Found ELF at offset $ELF_OFFSET"
+    tail -c +$((ELF_OFFSET + 1)) "$BZIMAGE" > "$VMLINUX"
+    if file "$VMLINUX" | grep -q 'ELF' && dump_btf_header "$VMLINUX"; then
+        exit 0
+    fi
+fi
+
+# Try to find raw BTF data in the image (ARM64 Image stores BTF as raw data).
+# The embedded Python scans for the little-endian BTF magic (0xeB9F), requires version 1,
+# and parses the rest of struct btf_header (hdr_len, type_off, type_len, str_off, str_len;
+# section offsets are relative to the end of the header) to compute the blob size.
+# It prints "<offset> <size>" for the first plausible candidate and nothing otherwise.
+BTF_RAW="$TMPDIR/btf.raw"
+BTF_OFFSET=$(python3 -c "
+import struct, sys
+data = open(sys.argv[1], 'rb').read()
+magic = b'\x9f\xeb'
+hdr_size = 24
+idx = 0
+while True:
+    idx = data.find(magic, idx)
+    # Stop when there is no further match or not enough bytes left for a full header.
+    if idx == -1 or idx + hdr_size > len(data): break
+    if data[idx+2] == 1: # version 1
+        hdr_len, type_off, type_len, str_off, str_len = struct.unpack_from('<IIIII', data, idx + 4)
+        total = hdr_len + max(type_off + type_len, str_off + str_len)
+        # Reject tiny or out-of-bounds candidates: they are chance matches of the 2-byte magic.
+        if hdr_len >= hdr_size and 1000 < total <= len(data) - idx:
+            print(f'{idx} {total}')
+            break
+    idx += 1
+" "$BZIMAGE" 2>/dev/null) || true
+
+if [ -n "$BTF_OFFSET" ]; then
+    BTF_OFF=$(echo "$BTF_OFFSET" | awk '{print $1}')
+    BTF_SIZE=$(echo "$BTF_OFFSET" | awk '{print $2}')
+    echo "Found raw BTF at offset $BTF_OFF ($BTF_SIZE bytes)"
+    tail -c +$((BTF_OFF + 1)) "$BZIMAGE" | head -c "$BTF_SIZE" > "$BTF_RAW"
+    if "$BPFTOOL" btf dump file "$BTF_RAW" format c > "$OUTPUT" 2>/dev/null; then
+        LINES=$(wc -l < "$OUTPUT")
+        if [ "$LINES" -gt 100 ]; then
+            echo "Done. Generated $OUTPUT ($LINES lines)"
+            exit 0
+        fi
+    fi
+fi
+
+# Find gzip offset and decompress to get the vmlinux ELF
+GZIP_OFFSET=$(binwalk -y gzip "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true
+
+if [ -z "$GZIP_OFFSET" ]; then
+    echo "Error: no gzip or ELF payload found in $BZIMAGE" >&2
+    exit 1
+fi
+
+echo "Found gzip payload at offset $GZIP_OFFSET"
+# The exit status of the pipeline is deliberately ignored; the result is validated with `file` below.
+tail -c +$((GZIP_OFFSET + 1)) "$BZIMAGE" | zcat > "$VMLINUX" 2>/dev/null || true
+
+# `file -b` omits the file name, so a temporary path that happens to contain
+# the letters "ELF" cannot produce a false match.
+if ! file -b "$VMLINUX" | grep -q 'ELF'; then
+    # The decompressed payload does not start with an ELF; it might wrap one.
+    # Reuse the data decompressed above instead of decompressing the image a second time.
+    INNER="$TMPDIR/inner"
+    mv "$VMLINUX" "$INNER"
+
+    # Search for ELF magic in decompressed data
+    ELF_OFFSET=$(grep -a -b -o -P '\x7fELF' "$INNER" 2>/dev/null | head -1 | cut -d: -f1) || true
+    if [ -n "$ELF_OFFSET" ]; then
+        tail -c +$((ELF_OFFSET + 1)) "$INNER" > "$VMLINUX"
+    fi
+fi
+
+if [ ! -f "$VMLINUX" ] || ! file -b "$VMLINUX" | grep -q 'ELF'; then
+    echo "Error: could not extract a valid ELF vmlinux from $BZIMAGE" >&2
+    exit 1
+fi
+
+echo "Extracted vmlinux: $(file "$VMLINUX")"
+echo "Generating vmlinux.h with bpftool..."
+"$BPFTOOL" btf dump file "$VMLINUX" format c > "$OUTPUT"
+
+LINES=$(wc -l < "$OUTPUT")
+echo "Done. Generated $OUTPUT ($LINES lines)"
diff --git a/src/linux/bpf/listen_monitor.bpf.c b/src/linux/bpf/listen_monitor.bpf.c
new file mode 100644
index 000000000..b89bada00
--- /dev/null
+++ b/src/linux/bpf/listen_monitor.bpf.c
@@ -0,0 +1,123 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include "vmlinux.h"
+#include
+#include
+#include
+#include
+#include "listen_monitor.h"
+
+struct { /* Ring buffer map through which listen_event records are published */
+ __uint(type, BPF_MAP_TYPE_RINGBUF);
+ __uint(max_entries, LISTEN_MONITOR_RINGBUF_SIZE);
+} events SEC(".maps");
+
+// Snapshot of each socket's listen_event, stored when inet_listen() succeeds, so that the release
+// hook can emit the values captured at listen time even if the kernel has already cleared the socket fields.
+struct {
+ __uint(type, BPF_MAP_TYPE_HASH);
+ __type(key, __u64); // sock pointer as key
+ __type(value, struct listen_event); // snapshot from listen time
+ __uint(max_entries, 65536); // maximum number of entries the map is declared with
+} listening_sockets SEC(".maps");
+
+/*
+ * Fill the family, port and local address fields of *info from a socket.
+ *
+ * sk:   socket to read from.
+ * info: record to fill; is_listen and pad are not touched.
+ *
+ * Returns 0 on success, or -1 when the socket family is neither AF_INET nor
+ * AF_INET6 (in that case only info->family has been written).
+ */
+static __always_inline int read_sock_info(struct sock *sk, struct listen_event *info)
+{
+    info->family = BPF_CORE_READ(sk, __sk_common.skc_family);
+    if (info->family != AF_INET && info->family != AF_INET6)
+        return -1;
+
+    info->port = BPF_CORE_READ(sk, __sk_common.skc_num);
+
+    /* Clear both address fields first; only the one matching the family is filled in. */
+    info->addr4 = 0;
+    __builtin_memset(info->addr6, 0, sizeof(info->addr6));
+
+    if (info->family == AF_INET)
+    {
+        info->addr4 = BPF_CORE_READ((struct inet_sock *)sk, inet_saddr);
+    }
+    else
+    {
+        struct ipv6_pinfo *pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6);
+
+        /* BPF_CORE_READ_INTO copies sizeof(*dst) bytes. Passing the array itself
+         * decays to a __u8 pointer and copies a single byte, so pass the address
+         * of the array to copy all 16 bytes of the IPv6 address. */
+        if (pinet6)
+            BPF_CORE_READ_INTO(&info->addr6, pinet6, saddr.in6_u.u6_addr8);
+    }
+
+    return 0;
+}
+
+// Attached at the exit of inet_listen(struct socket *sock, int backlog).
+// On success, snapshots the socket's family/port/address, remembers the snapshot in
+// listening_sockets keyed by the sock pointer, and publishes it as a "started listening" event.
+SEC("fexit/inet_listen")
+int BPF_PROG(fexit_inet_listen, struct socket *sock, int backlog, int ret)
+{
+    struct listen_event snapshot = {};
+    struct listen_event *out;
+    struct sock *sk;
+    __u64 key;
+
+    // Failed listen() calls are not reported.
+    if (ret != 0)
+        return 0;
+
+    sk = BPF_CORE_READ(sock, sk);
+    if (!sk)
+        return 0;
+
+    key = (__u64)sk;
+
+    // Sockets of other address families are ignored.
+    if (read_sock_info(sk, &snapshot) < 0)
+        return 0;
+
+    snapshot.is_listen = 1;
+
+    // Remember the snapshot so the release hook can report the same values later.
+    bpf_map_update_elem(&listening_sockets, &key, &snapshot, BPF_ANY);
+
+    out = bpf_ringbuf_reserve(&events, sizeof(*out), 0);
+    if (out)
+    {
+        *out = snapshot;
+        bpf_ringbuf_submit(out, 0);
+    }
+
+    return 0;
+}
+
+// Attached at the entry of inet_release(struct socket *sock), i.e. when a socket is being closed.
+// If the socket was recorded by the listen hook, publishes its stored snapshot as a
+// "stopped listening" event and forgets the socket. The map entry is removed even when
+// no ring buffer space is available for the event.
+SEC("fentry/inet_release")
+int BPF_PROG(fentry_inet_release, struct socket *sock)
+{
+    struct listen_event *snapshot;
+    struct listen_event *out;
+    struct sock *sk;
+    __u64 key;
+
+    sk = BPF_CORE_READ(sock, sk);
+    if (!sk)
+        return 0;
+
+    key = (__u64)sk;
+
+    // Only sockets recorded at listen time are of interest.
+    snapshot = bpf_map_lookup_elem(&listening_sockets, &key);
+    if (!snapshot)
+        return 0;
+
+    out = bpf_ringbuf_reserve(&events, sizeof(*out), 0);
+    if (out)
+    {
+        // Copy the snapshot before its map entry is deleted.
+        *out = *snapshot;
+        out->is_listen = 0;
+
+        bpf_map_delete_elem(&listening_sockets, &key);
+        bpf_ringbuf_submit(out, 0);
+        return 0;
+    }
+
+    bpf_map_delete_elem(&listening_sockets, &key);
+    return 0;
+}
+
+char LICENSE[] SEC("license") = "GPL";
diff --git a/src/linux/bpf/listen_monitor.h b/src/linux/bpf/listen_monitor.h
new file mode 100644
index 000000000..58e684dcd
--- /dev/null
+++ b/src/linux/bpf/listen_monitor.h
@@ -0,0 +1,28 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#ifndef __u8 /* NOTE(review): __u8 is declared by typedef below, not #define, so this test looks always-true unless another header defines a __u8 macro - confirm the guard is intended */
+typedef unsigned char __u8;
+typedef unsigned short __u16;
+typedef unsigned int __u32;
+typedef unsigned long long __u64;
+#endif
+
+#ifndef AF_INET /* fallback values used only when the macros are not already defined */
+#define AF_INET 2
+#endif
+#ifndef AF_INET6
+#define AF_INET6 10
+#endif
+
+#define LISTEN_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB; used as max_entries of the "events" ring buffer map */
+
+struct listen_event { /* one record per listen start/stop; produced by listen_monitor.bpf.c */
+ __u32 family; /* AF_INET or AF_INET6 */
+ __u16 port; /* host byte order */
+ __u8 is_listen; /* 1 = started listening, 0 = stopped listening */
+ __u8 pad; /* explicit padding byte */
+ __u32 addr4; /* IPv4 address (network byte order); 0 for AF_INET6 events */
+ __u8 addr6[16]; /* IPv6 address; all zero for AF_INET events */
+};
diff --git a/src/linux/init/CMakeLists.txt b/src/linux/init/CMakeLists.txt
index d5dee8005..ef19f6812 100644
--- a/src/linux/init/CMakeLists.txt
+++ b/src/linux/init/CMakeLists.txt
@@ -46,7 +46,14 @@ set(HEADERS
wslpath.h)
set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${CMAKE_CURRENT_LIST_DIR}/../netlinkutil")
+set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/bpf/${TARGET_PLATFORM}/include") # NOTE(review): presumably where bind_monitor.h and bind_monitor.skel.h are shipped - confirm
+set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/include") # libbpf headers from the same package directory
set(INIT_LIBRARIES ${COMMON_LINUX_LINK_LIBRARIES} netlinkutil plan9 mountutil configfile)
+set(INIT_EXTRA_LIBS # static archives appended to LINUX_LDFLAGS below
+ ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libbpf.a
+ ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libelf.a
+ ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libz.a)
+set(LINUX_LDFLAGS ${LINUX_LDFLAGS} ${INIT_EXTRA_LIBS})
add_linux_executable(init "${SOURCES}" "${HEADERS};${COMMON_LINUX_HEADERS}" "${INIT_LIBRARIES}")
add_dependencies(init localization)
diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp
index 669582a14..dfa0ca145 100644
--- a/src/linux/init/GnsPortTracker.cpp
+++ b/src/linux/init/GnsPortTracker.cpp
@@ -3,16 +3,19 @@
#include
#include
#include
-#include /* Definition of AUDIT_* constants */
+#include
#include
#include
#include
#include
#include
-#include "common.h" // Needs to be included before sal.h before of __reserved macro
+#include
+#include "common.h"
#include "NetlinkTransactionError.h"
#include "GnsPortTracker.h"
#include "lxinitshared.h"
+#include "bind_monitor.h"
+#include "bind_monitor.skel.h"
constexpr size_t c_bind_timeout_seconds = 60;
constexpr auto c_sock_diag_refresh_delay = std::chrono::milliseconds(500);
@@ -20,8 +23,14 @@ constexpr auto c_sock_diag_poll_timeout = std::chrono::milliseconds(10);
constexpr auto c_bpf_poll_timeout = std::chrono::milliseconds(500);
GnsPortTracker::GnsPortTracker(
- std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher) :
- m_hvSocketChannel(std::move(hvSocketChannel)), m_channel(std::move(netlinkChannel)), m_seccompDispatcher(seccompDispatcher)
+ std::shared_ptr hvSocketChannel,
+ NetlinkChannel&& netlinkChannel,
+ std::shared_ptr seccompDispatcher,
+ std::shared_ptr channelMutex) : // mutex that RequestPort locks around m_hvSocketChannel->Transaction()
+ m_hvSocketChannel(std::move(hvSocketChannel)),
+ m_channelMutex(std::move(channelMutex)),
+ m_channel(std::move(netlinkChannel)),
+ m_seccompDispatcher(seccompDispatcher)
{
m_networkNamespace = std::filesystem::read_symlink("/proc/self/ns/net").string(); // this process's network namespace; GetCallInfo skips bind() calls made from a different one
}
@@ -65,15 +74,116 @@ int GnsPortTracker::ProcessSecCompNotification(seccomp_notif* notification)
return m_reply.get();
}
-void GnsPortTracker::Run()
+void GnsPortTracker::OnBindEvent(void* Data)
{
- // This method consumes seccomp notifications and allows / disallows port allocations
- // depending on wsl core's response.
- // After dealing with a notification it also looks at the bound ports list to check
- // for port deallocation
+ const auto* event = static_cast(Data); // NOTE(review): Data is assumed to point at a complete bind_event; the ring buffer callback does not check its size argument - confirm
+
+ in6_addr address = {};
+ if (event->family == AF_INET)
+ {
+ address.s6_addr32[0] = event->addr4; // IPv4: only the first 32-bit word is set, the rest stays zero
+ }
+ else if (event->family == AF_INET6)
+ {
+ memcpy(address.s6_addr32, event->addr6, sizeof(address.s6_addr32));
+ }
+ else
+ {
+ return; // events for any other address family are ignored
+ }
+
+ PortAllocation allocation(event->port, event->family, event->protocol, address);
+
+ // Check if this port is already tracked by the seccomp path (non-zero bind).
+ // If it is, this is not a port-0 bind, so skip it.
+ {
+ std::lock_guard lock(m_portsMutex);
+ if (m_allocatedPorts.contains(allocation))
+ {
+ return;
+ }
+ }
+
+ // This is a port-0 bind that seccomp let through. Report it to the host.
+ const auto result = HandleRequest(allocation); // HandleRequest checks m_allocatedPorts again under m_portsMutex
+ if (result == 0)
+ {
+ TrackPort(std::move(allocation)); // TrackPort stores the allocation in m_allocatedPorts with a fresh timeout
+ }
+ else
+ {
+ GNS_LOG_ERROR(
+ "Failed to register port-0 bind: family ({}) port ({}) protocol ({}), error {}",
+ allocation.Family,
+ allocation.Port,
+ allocation.Protocol,
+ result);
+ }
+}
+void GnsPortTracker::Run()
+{
std::thread{std::bind(&GnsPortTracker::RunPortRefresh, this)}.detach();
- std::thread{std::bind(&GnsPortTracker::RunDeferredResolve, this)}.detach();
+
+ // Start eBPF bind monitor on a background thread to handle port-0 bind resolution.
+ std::thread{[this]() {
+ try
+ {
+ auto* skel = bind_monitor_bpf__open_and_load();
+ if (!skel)
+ {
+ LOG_ERROR("Failed to open/load bind monitor BPF program, {}", errno);
+ return;
+ }
+
+ auto destroySkel = wil::scope_exit([&] { bind_monitor_bpf__destroy(skel); });
+
+ if (bind_monitor_bpf__attach(skel) != 0)
+ {
+ LOG_ERROR("Failed to attach bind monitor BPF program, {}", errno);
+ return;
+ }
+
+ auto onEvent = [](void* ctx, void* data, size_t dataSz) noexcept -> int {
+ try
+ {
+ static_cast(ctx)->OnBindEvent(data);
+ }
+ catch (...)
+ {
+ LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event");
+ }
+ return 0;
+ };
+
+ auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), onEvent, this, nullptr);
+ if (!rb)
+ {
+ LOG_ERROR("Failed to create bind monitor ring buffer, {}", errno);
+ return;
+ }
+
+ auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); });
+
+ GNS_LOG_INFO("BPF bind monitor for port-0 resolution attached and running");
+
+ for (;;)
+ {
+ int err = ring_buffer__poll(rb, -1);
+ if (err == -EINTR)
+ {
+ continue;
+ }
+
+ if (err < 0)
+ {
+ LOG_ERROR("bind monitor ring_buffer__poll failed, {}", err);
+ return;
+ }
+ }
+ }
+ CATCH_LOG()
+ }}.detach();
auto future = std::make_optional(m_allocatedPortsRefresh.get_future());
std::optional refreshResult;
@@ -117,19 +227,8 @@ void GnsPortTracker::Run()
GNS_LOG_ERROR("Failed to complete bind request, {}", e.what());
}
- if (bindCall->PortZeroBind.has_value())
- {
- try
- {
- std::lock_guard lock(m_deferredMutex);
- m_deferredQueue.push_back(std::move(bindCall->PortZeroBind.value()));
- m_deferredCv.notify_one();
- }
- catch (const std::exception& e)
- {
- GNS_LOG_ERROR("Failed to queue port-0 bind for deferred resolution, {}", e.what());
- }
- }
+ // Port-0 binds are now handled by the eBPF callback (OnBindEvent),
+ // so no deferred resolution is needed here.
}
// If bindCall is empty, then the read() timed out. Look for any closed port
@@ -139,46 +238,21 @@ void GnsPortTracker::Run()
future.reset();
m_allocatedPortsRefresh = {};
- // If this loop's iteration had a bind call, it's possible that RefreshAllocatedPort
- // was called before the bind called was processed. Make sure that the port list
- // is up to date (If this is called, the next block will schedule another refresh)
if (!bindCall.has_value())
{
+ std::lock_guard lock(m_portsMutex);
OnRefreshAllocatedPorts(refreshResult->Ports, refreshResult->Timestamp);
}
}
- // Process any port-0 binds that the background thread has resolved.
- std::deque resolved;
- {
- std::lock_guard lock(m_resolvedMutex);
- resolved.swap(m_resolvedQueue);
- }
- for (auto& allocation : resolved)
- {
- const auto result = HandleRequest(allocation);
- if (result == 0)
- {
- TrackPort(std::move(allocation));
- }
- else
- {
- GNS_LOG_ERROR(
- "Failed to register resolved port-0 bind: family ({}) port ({}) protocol ({}), error {}",
- allocation.Family,
- allocation.Port,
- allocation.Protocol,
- result);
- }
- }
-
// Only look at bound ports if there's something to deallocate to avoid wasting cycles
if (refreshResult.has_value())
{
+ std::lock_guard lock(m_portsMutex);
if (!m_allocatedPorts.empty())
{
future = m_allocatedPortsRefresh.get_future();
- refreshResult->Resume(); // This will resume the sock_diag thread
+ refreshResult->Resume();
refreshResult.reset();
}
}
@@ -244,7 +318,9 @@ std::set GnsPortTracker::ListAllocatedPorts()
void GnsPortTracker::OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp)
{
- // Because there's no way to get notified when the bind() call actually completes, it' possible
+ // m_portsMutex must be held by caller.
+ //
+ // Because there's no way to get notified when the bind() call actually completes, it's possible
// that this method is called before the bind() completion and so the port allocation may not be visible yet.
// To avoid deallocating ports that simply haven't been done allocating yet, m_allocatedPorts stores a timeout
// that prevents deallocating the port unless:
@@ -296,6 +372,7 @@ int GnsPortTracker::RequestPort(const PortAllocation& Port, bool allocate)
static_assert(sizeof(request.Address32) == sizeof(Port.Address.s6_addr32));
memcpy(request.Address32, Port.Address.s6_addr32, sizeof(request.Address32));
+ std::lock_guard lock(*m_channelMutex);
const auto& response = m_hvSocketChannel->Transaction(request);
return response.Result;
@@ -303,14 +380,13 @@ int GnsPortTracker::RequestPort(const PortAllocation& Port, bool allocate)
int GnsPortTracker::HandleRequest(const PortAllocation& Port)
{
- // If the port is already allocated, let the call go through and the kernel will
- // decide if bind() should succeed or not
- // Note: Returning 0 will also cause the port's timeout to be updated
-
- if (m_allocatedPorts.contains(Port))
{
- GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol);
- return 0;
+ std::lock_guard lock(m_portsMutex);
+ if (m_allocatedPorts.contains(Port))
+ {
+ GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol);
+ return 0;
+ }
}
// Ask the host for this port otherwise
@@ -322,7 +398,6 @@ int GnsPortTracker::HandleRequest(const PortAllocation& Port)
std::optional GnsPortTracker::ReadNextRequest()
{
- // Read the call information
auto request_value = m_request.try_get(c_bpf_poll_timeout);
if (!request_value.has_value())
{
@@ -331,11 +406,6 @@ std::optional GnsPortTracker::ReadNextRequest()
auto callInfo = request_value.value();
- // This logic needs to be defensive because the calling process is blocked until
- // CompleteRequest() is called, so if the call information can't be processed because
- // the caller has done something wrong (bad pointer, fd, or protocol), just let it go through
- // and the kernel will fail it
-
try
{
return GetCallInfo(callInfo.id, callInfo.pid, callInfo.data.arch, callInfo.data.nr, gsl::make_span(callInfo.data.args));
@@ -343,7 +413,7 @@ std::optional GnsPortTracker::ReadNextRequest()
catch (const std::exception& e)
{
GNS_LOG_ERROR("Failed to read bind() call info with ID {} for pid {}, {}", callInfo.id, callInfo.pid, e.what());
- return {{{}, {}, callInfo.id}};
+ return {{{}, false, callInfo.id}};
}
}
@@ -353,14 +423,14 @@ std::optional GnsPortTracker::GetCallInfo(
auto ParseSocket = [&](int Socket, size_t AddressPtr, size_t AddressLength) -> std::optional {
if (AddressLength < sizeof(sockaddr))
{
- return {{{}, {}, CallId}}; // Invalid sockaddr. Let it go through.
+ return {{{}, false, CallId}};
}
auto networkNamespace = std::filesystem::read_symlink(std::format("/proc/{}/ns/net", Pid)).string();
if (networkNamespace != m_networkNamespace)
{
GNS_LOG_INFO("Skipping bind() call for pid {} in network namespace {}", Pid, networkNamespace.c_str());
- return {{{}, {}, CallId}}; // Different network namespace. Let it go through.
+ return {{{}, false, CallId}};
}
auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, AddressPtr, AddressLength);
@@ -374,40 +444,19 @@ std::optional GnsPortTracker::GetCallInfo(
if ((address.sa_family != AF_INET && address.sa_family != AF_INET6) ||
(address.sa_family == AF_INET6 && AddressLength < sizeof(sockaddr_in6)))
{
- return {{{}, {}, CallId}}; // This is a non IP call, or invalid sockaddr_in6. Let it go through
+ return {{{}, false, CallId}};
}
- // Read the port. The port *happens* to be in the same spot in memory for both sockaddr_in
- // and sockaddr_in6. To avoid a second memory read, we take advantage of this fact to fetch
- // the port from the currently read memory, regardless of the address family.
static_assert(sizeof(sockaddr_in) <= sizeof(sockaddr));
const auto* inAddr = reinterpret_cast(&address);
in_port_t port = ntohs(inAddr->sin_port);
if (port == 0)
{
- // Port 0 means the kernel will assign an ephemeral port. We can't know
- // the port until after the bind() completes, so duplicate the socket fd
- // now (while the process is still stopped by seccomp) and defer the
- // getsockname() lookup to after CompleteRequest() unblocks it.
- try
- {
- const int protocol = GetSocketProtocol(Pid, Socket);
- auto dupFd = DuplicateSocketFd(Pid, Socket);
- if (!dupFd)
- {
- return {{{}, {}, CallId}};
- }
- if (!m_seccompDispatcher->ValidateCookie(CallId))
- {
- return {{{}, {}, CallId}};
- }
- return {{{}, DeferredPortLookup{Pid, std::move(dupFd), protocol}, CallId}};
- }
- catch (const std::exception&)
- {
- return {{{}, {}, CallId}}; // Can't determine protocol, just let it through
- }
+ // Port 0 means the kernel will assign an ephemeral port.
+ // The eBPF fexit handler (OnBindEvent) will resolve the actual port
+ // after bind() completes. Just let the call through.
+ return {{{}, true, CallId}};
}
in6_addr storedAddress = {};
@@ -422,35 +471,26 @@ std::optional GnsPortTracker::GetCallInfo(
memcpy(storedAddress.s6_addr32, inAddr6->sin6_addr.s6_addr32, sizeof(storedAddress.s6_addr32));
}
- // It's possible that the calling process lied and passed a sockaddr that
- // doesn't match the underlying socket family or a bad fd. If that's the case,
- // then GetSocketProtocol() will throw
const int protocol = GetSocketProtocol(Pid, Socket);
- // As GetSocketProtocol interacts with /proc/, to avoid TOCTOU races we need to
- // verify that call is still valid (call id is the same thing as cookie)
if (!m_seccompDispatcher->ValidateCookie(CallId))
{
throw RuntimeErrorWithSourceLocation(std::format("Invalid call id {}", CallId));
}
- return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, {}, CallId}};
+ return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, false, CallId}};
};
#ifdef __x86_64__
if (Arch & __AUDIT_ARCH_64BIT)
{
return ParseSocket(Arguments[0], Arguments[1], Arguments[2]);
}
- // Note: 32bit on x86_64 uses the __NR_socketcall with the first argument
- // set to SYS_BIND to make bind system call and the second argument is
- // a pointer to a block of memory containing the original arguments.
else
{
if (Arguments[0] != SYS_BIND)
{
- return {{{}, {}, CallId}}; // Not a bind call, just let the call go through
+ return {{{}, false, CallId}};
}
- // Grab the first 3 parameters
auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, Arguments[1], sizeof(uint32_t) * 3);
if (!processMemory.has_value())
{
@@ -506,32 +546,12 @@ int GnsPortTracker::GetSocketProtocol(int pid, int fd)
throw RuntimeErrorWithSourceLocation(std::format("Unexpected IP socket protocol: {}", protocol));
}
-wil::unique_fd GnsPortTracker::DuplicateSocketFd(pid_t Pid, int SocketFd)
-{
- // Duplicate the socket fd from the target process into our address space.
- // We cannot use open("/proc/pid/fd/N") for sockets because the symlink target
- // (socket:[inode]) is not a valid filesystem path. Use pidfd_getfd() instead.
- wil::unique_fd pidFd(static_cast(syscall(SYS_pidfd_open, Pid, 0u)));
- if (!pidFd)
- {
- GNS_LOG_INFO("Port-0 bind: pidfd_open failed for pid {} (errno {})", Pid, errno);
- return {};
- }
-
- wil::unique_fd dupFd(static_cast(syscall(SYS_pidfd_getfd, pidFd.get(), SocketFd, 0u)));
- if (!dupFd)
- {
- GNS_LOG_INFO("Port-0 bind: pidfd_getfd failed for pid {} fd {} (errno {})", Pid, SocketFd, errno);
- }
-
- return dupFd;
-}
-
void GnsPortTracker::TrackPort(PortAllocation allocation)
try
{
// Use insert_or_assign so the deallocation timeout is refreshed if the same
// port key is already present (emplace would silently keep the old entry).
+ std::lock_guard lock(m_portsMutex);
m_allocatedPorts.insert_or_assign(std::move(allocation), std::make_optional(time(nullptr) + c_bind_timeout_seconds));
}
catch (const std::exception& e)
@@ -539,104 +559,6 @@ catch (const std::exception& e)
GNS_LOG_ERROR("Failed to track port allocation, {}", e.what());
}
-void GnsPortTracker::RunDeferredResolve()
-{
- UtilSetThreadName("GnsPortZero");
-
- for (;;)
- {
- DeferredPortLookup lookup{0, {}, 0};
- {
- std::unique_lock lock(m_deferredMutex);
- m_deferredCv.wait(lock, [&] { return !m_deferredQueue.empty(); });
- lookup = std::move(m_deferredQueue.front());
- m_deferredQueue.pop_front();
- }
-
- const auto pid = lookup.Pid;
- try
- {
- ResolvePortZeroBind(std::move(lookup));
- }
- catch (const std::exception& e)
- {
- GNS_LOG_ERROR("Failed to resolve port-0 bind for pid {}, {}", pid, e.what());
- }
- }
-}
-
-void GnsPortTracker::ResolvePortZeroBind(DeferredPortLookup lookup)
-{
- // The socket fd was already duplicated (via pidfd_getfd) while the target process
- // was stopped by seccomp, so it remains valid even if the process has closed or
- // reused the original fd number.
-
- // The bind() syscall is being completed asynchronously on the seccomp dispatcher
- // thread after CompleteRequest() unblocks it. Poll getsockname() briefly until
- // the kernel assigns a port.
- constexpr int maxRetries = 25;
- constexpr auto retryDelay = std::chrono::milliseconds(100);
-
- in_port_t port = 0;
- in6_addr address = {};
- int resolvedFamily = 0;
-
- for (int attempt = 0; attempt < maxRetries; ++attempt)
- {
- if (attempt > 0)
- {
- std::this_thread::sleep_for(retryDelay);
- }
-
- sockaddr_storage storage{};
- socklen_t addrLen = sizeof(storage);
- if (getsockname(lookup.DuplicatedSocketFd.get(), reinterpret_cast(&storage), &addrLen) != 0)
- {
- GNS_LOG_ERROR("Port-0 bind: getsockname failed for pid {} (errno {})", lookup.Pid, errno);
- return;
- }
-
- resolvedFamily = static_cast(storage.ss_family);
-
- if (storage.ss_family == AF_INET)
- {
- const auto* sin = reinterpret_cast(&storage);
- port = ntohs(sin->sin_port);
- address.s6_addr32[0] = sin->sin_addr.s_addr;
- }
- else if (storage.ss_family == AF_INET6)
- {
- const auto* sin6 = reinterpret_cast(&storage);
- port = ntohs(sin6->sin6_port);
- memcpy(address.s6_addr32, sin6->sin6_addr.s6_addr32, sizeof(address.s6_addr32));
- }
- else
- {
- GNS_LOG_ERROR("Port-0 bind: unexpected address family ({}) for pid {}", resolvedFamily, lookup.Pid);
- return;
- }
-
- if (port != 0)
- {
- break;
- }
- }
-
- if (port == 0)
- {
- GNS_LOG_ERROR("Port-0 bind: kernel did not assign a port for pid {} after retries", lookup.Pid);
- return;
- }
-
- PortAllocation allocation(port, resolvedFamily, lookup.Protocol, address);
- GNS_LOG_INFO(
- "Port-0 bind resolved: family ({}) port ({}) protocol ({}) for pid {}", resolvedFamily, port, lookup.Protocol, lookup.Pid);
- {
- std::lock_guard lock(m_resolvedMutex);
- m_resolvedQueue.push_back(std::move(allocation));
- }
-}
-
std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& entry)
{
return out << "Port=" << entry.Port << ", Family=" << entry.Family << ", Protocol=" << entry.Protocol;
diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h
index a3f2945a1..2dfa2a8dd 100644
--- a/src/linux/init/GnsPortTracker.h
+++ b/src/linux/init/GnsPortTracker.h
@@ -21,17 +21,22 @@
class GnsPortTracker
{
public:
- GnsPortTracker(std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher);
+ GnsPortTracker(
+ std::shared_ptr hvSocketChannel,
+ NetlinkChannel&& netlinkChannel,
+ std::shared_ptr seccompDispatcher,
+ std::shared_ptr channelMutex);
- GnsPortTracker(const GnsPortTracker&) = delete;
- GnsPortTracker(GnsPortTracker&&) = delete;
- GnsPortTracker& operator=(const GnsPortTracker&) = delete;
- GnsPortTracker& operator=(GnsPortTracker&&) = delete;
+ NON_COPYABLE(GnsPortTracker);
+ NON_MOVABLE(GnsPortTracker);
void Run();
int ProcessSecCompNotification(seccomp_notif* notification);
+ // Called from the eBPF ring buffer callback thread to handle port-0 bind resolution.
+ void OnBindEvent(void* data);
+
struct PortAllocation
{
in6_addr Address = {};
@@ -94,27 +99,10 @@ class GnsPortTracker
}
};
- struct DeferredPortLookup
- {
- pid_t Pid;
- wil::unique_fd DuplicatedSocketFd; // Duplicated via pidfd_getfd while process was stopped
- int Protocol;
-
- DeferredPortLookup(pid_t Pid, wil::unique_fd DuplicatedSocketFd, int Protocol) :
- Pid(Pid), DuplicatedSocketFd(std::move(DuplicatedSocketFd)), Protocol(Protocol)
- {
- }
-
- DeferredPortLookup(DeferredPortLookup&&) = default;
- DeferredPortLookup& operator=(DeferredPortLookup&&) = default;
- DeferredPortLookup(const DeferredPortLookup&) = delete;
- DeferredPortLookup& operator=(const DeferredPortLookup&) = delete;
- };
-
struct BindCall
{
std::optional Request;
- std::optional PortZeroBind;
+ bool IsPortZeroBind = false;
std::uint64_t CallId;
};
@@ -144,16 +132,14 @@ class GnsPortTracker
static int GetSocketProtocol(int Pid, int Fd);
- static wil::unique_fd DuplicateSocketFd(pid_t Pid, int SocketFd);
-
- void ResolvePortZeroBind(DeferredPortLookup lookup);
-
- void RunDeferredResolve();
-
void TrackPort(PortAllocation allocation);
+ // Protects m_allocatedPorts for concurrent access from the main loop and the eBPF callback thread.
+ std::mutex m_portsMutex;
std::map> m_allocatedPorts;
+
std::shared_ptr m_hvSocketChannel;
+ std::shared_ptr m_channelMutex;
NetlinkChannel m_channel;
std::promise m_allocatedPortsRefresh;
@@ -163,15 +149,6 @@ class GnsPortTracker
std::shared_ptr m_seccompDispatcher;
std::string m_networkNamespace;
-
- std::mutex m_deferredMutex;
- std::condition_variable m_deferredCv;
- std::deque m_deferredQueue;
-
- // Resolved port-0 allocations posted by the background RunDeferredResolve thread
- // for the main Run() loop to process (keeps SocketChannel access single-threaded).
- std::mutex m_resolvedMutex;
- std::deque m_resolvedQueue;
};
std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& portAllocation);
diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp
index f1af32182..3d7e3c3db 100644
--- a/src/linux/init/localhost.cpp
+++ b/src/linux/init/localhost.cpp
@@ -1,6 +1,7 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "common.h"
#include
+#include
#include
#include
#include
@@ -12,11 +13,12 @@
#include
#include
#include
-#include
-#include
#include
+#include
#include
+#include
+
#include "util.h"
#include "SocketChannel.h"
#include "GnsPortTracker.h"
@@ -25,8 +27,8 @@
#include "CommandLine.h"
#include "NetlinkChannel.h"
#include "NetlinkTransactionError.h"
-
-#define TCP_LISTEN 10
+#include "listen_monitor.h"
+#include "listen_monitor.skel.h"
namespace {
@@ -149,64 +151,6 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
return;
}
-std::vector QueryListeningSockets(NetlinkChannel& channel)
-{
- std::vector sockets{};
- try
- {
- inet_diag_req_v2 message{};
- message.sdiag_protocol = IPPROTO_TCP;
- message.idiag_states = (1 << TCP_LISTEN);
-
- auto onMessage = [&](const NetlinkResponse& response) {
- for (const auto& e : response.Messages(SOCK_DIAG_BY_FAMILY))
- {
- const auto* payload = e.Payload();
- sockaddr_storage sock{};
-
- if (payload->idiag_family == AF_INET)
- {
- auto* ipv4 = reinterpret_cast(&sock);
- ipv4->sin_family = AF_INET;
- ipv4->sin_addr.s_addr = payload->id.idiag_src[0];
- ipv4->sin_port = payload->id.idiag_sport;
- }
- else if (payload->idiag_family == AF_INET6)
- {
- auto* ipv6 = reinterpret_cast(&sock);
- ipv6->sin6_family = AF_INET6;
- static_assert(sizeof(ipv6->sin6_addr.s6_addr32) == sizeof(payload->id.idiag_src));
- memcpy(ipv6->sin6_addr.s6_addr32, payload->id.idiag_src, sizeof(ipv6->sin6_addr.s6_addr32));
- ipv6->sin6_port = payload->id.idiag_sport;
- }
-
- sockets.emplace_back(sock);
- }
- };
-
- // Query IPv4 listening sockets.
- {
- message.sdiag_family = AF_INET;
- auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
- transaction.Execute(onMessage);
- }
-
- // Query IPv6 listening sockets.
- {
- message.sdiag_family = AF_INET6;
- auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
- transaction.Execute(onMessage);
- }
- }
- catch (const NetlinkTransactionError& e)
- {
- // Log but don't fail - network state might be temporarily unavailable
- LOG_ERROR("Failed to query listening sockets via sock_diag: {}", e.what());
- }
-
- return sockets;
-}
-
int SendRelayListenerSocket(wsl::shared::SocketChannel& channel, int hvSocketPort)
try
{
@@ -263,87 +207,113 @@ try
}
CATCH_RETURN_ERRNO();
-bool IsSameSockAddr(const sockaddr_storage& left, const sockaddr_storage& right)
+// Ring buffer callback invoked for each listen monitor event. ctx is the SocketChannel
+// that MonitorListeningSockets passed to ring_buffer__new and data is one event record
+// of dataSz bytes. Starts or stops the host listener for the event's address and port.
+// Always returns 0 and never lets an exception escape into the C caller.
+extern "C" int OnListenMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept
+try
{
- if (left.ss_family != right.ss_family)
+ auto* channel = static_cast(ctx);
+ const auto* event = static_cast(data);
+
+ // Reject truncated records before reading any field of the event.
+ if (dataSz < sizeof(*event))
+ {
+ LOG_ERROR("Unexpected listen monitor event size, {}", dataSz);
+ return 0;
+ }
+
+ sockaddr_storage sock{};
+ if (event->family == AF_INET)
+ {
+ auto* ipv4 = reinterpret_cast(&sock);
+ ipv4->sin_family = AF_INET;
+ ipv4->sin_addr.s_addr = event->addr4;
+ ipv4->sin_port = htons(event->port);
+ }
+ else if (event->family == AF_INET6)
{
- return false;
+ auto* ipv6 = reinterpret_cast(&sock);
+ ipv6->sin6_family = AF_INET6;
+ memcpy(ipv6->sin6_addr.s6_addr, event->addr6, sizeof(ipv6->sin6_addr.s6_addr));
+ ipv6->sin6_port = htons(event->port);
+ }
+ else
+ {
+ return 0;
}
- if (left.ss_family == AF_INET)
+ // A failed relay update is logged but does not stop event processing.
+ if (event->is_listen)
{
- auto leftIpv4 = reinterpret_cast(&left);
- auto rightIpv4 = reinterpret_cast(&right);
- return (leftIpv4->sin_addr.s_addr == rightIpv4->sin_addr.s_addr && leftIpv4->sin_port == rightIpv4->sin_port);
+ if (StartHostListener(*channel, sock) < 0)
+ {
+ LOG_ERROR("Failed to start host listener for port {}", event->port);
+ }
}
- else if (left.ss_family == AF_INET6)
+ else
{
- auto leftIpv6 = reinterpret_cast(&left);
- auto rightIpv6 = reinterpret_cast(&right);
- return (leftIpv6->sin6_port == rightIpv6->sin6_port && memcmp(&leftIpv6->sin6_addr, &rightIpv6->sin6_addr, sizeof(in6_addr)) == 0);
+ if (StopHostListener(*channel, sock) < 0)
+ {
+ LOG_ERROR("Failed to stop host listener for port {}", event->port);
+ }
}
- FATAL_ERROR("Unrecognized socket family {}", left.ss_family);
- return false;
+ return 0;
+}
+catch (...)
+{
+ LOG_CAUGHT_EXCEPTION_MSG("Error processing listen monitor event");
+ return 0;
}
-// Monitor listening TCP sockets using sock_diag netlink interface.
+// Monitor listening sockets with the listen monitor BPF program: load and attach it, then
+// dispatch its ring buffer events to OnListenMonitorEvent. Only returns (-1) on failure.
int MonitorListeningSockets(wsl::shared::SocketChannel& channel)
{
- NetlinkChannel netlinkChannel(SOCK_RAW, NETLINK_SOCK_DIAG);
- std::vector relays{};
- int result = 0;
-
- for (;;)
+ auto* skel = listen_monitor_bpf__open_and_load();
+ if (!skel)
{
- auto sockets = QueryListeningSockets(netlinkChannel);
+ LOG_ERROR("Failed to open/load listen monitor BPF program, {}", errno);
+ return -1;
+ }
- // Stop any relays that no longer match listening ports.
- std::erase_if(relays, [&](const auto& entry) {
- auto found =
- std::find_if(sockets.begin(), sockets.end(), [&](const auto& socket) { return IsSameSockAddr(entry, socket); });
+ auto destroySkel = wil::scope_exit([&] { listen_monitor_bpf__destroy(skel); });
- bool remove = (found == sockets.end());
- if (remove)
- {
- if (StopHostListener(channel, entry) < 0)
- {
- result = -1;
- }
- }
+ if (listen_monitor_bpf__attach(skel) != 0)
+ {
+ LOG_ERROR("Failed to attach listen monitor BPF program, {}", errno);
+ return -1;
+ }
- return remove;
- });
+ auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), OnListenMonitorEvent, &channel, nullptr);
+ if (!rb)
+ {
+ LOG_ERROR("Failed to create ring buffer, {}", errno);
+ return -1;
+ }
- // Create relays for any new ports.
- std::for_each(sockets.begin(), sockets.end(), [&](const auto& socket) {
- auto found =
- std::find_if(relays.begin(), relays.end(), [&](const auto& entry) { return IsSameSockAddr(entry, socket); });
+ // Declared after destroySkel, so the ring buffer is freed before the skeleton is destroyed.
+ auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); });
- if (found == relays.end())
- {
- if (StartHostListener(channel, socket) < 0)
- {
- result = -1;
- }
- else
- {
- relays.push_back(socket);
- }
- }
- });
+ GNS_LOG_INFO("BPF listen monitor attached and running");
- // Ensure all start / stop operations were successful.
- if (result < 0)
+ for (;;)
+ {
+ // Block until events arrive (no timeout); a signal interruption is simply retried.
+ int err = ring_buffer__poll(rb, -1);
+ if (err == -EINTR)
{
- break;
+ continue;
}
- // Sleep before scanning again.
- std::this_thread::sleep_for(std::chrono::seconds(1));
+ if (err < 0)
+ {
+ LOG_ERROR("ring_buffer__poll failed, {}", err);
+ return -1;
+ }
}
-
- return result;
}
} // namespace
@@ -464,9 +412,11 @@ int RunPortTracker(int Argc, char** Argv)
auto channel = NetlinkChannel::FromFd(NetlinkSocketFd);
+ auto channelMutex = std::make_shared();
+
auto seccompDispatcher = std::make_shared(BpfFd);
- GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher);
+ GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher, channelMutex);
seccompDispatcher->RegisterHandler(
__NR_bind, [&portTracker](seccomp_notif* notification) { return portTracker.ProcessSecCompNotification(notification); });
@@ -481,7 +431,7 @@ int RunPortTracker(int Argc, char** Argv)
});
#endif
- seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher](auto notification) -> int {
+ seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher, channelMutex](auto notification) -> int {
LX_GNS_TUN_BRIDGE_REQUEST request{};
request.Header.MessageType = LxGnsMessageIfStateChangeRequest;
request.Header.MessageSize = sizeof(request);
@@ -495,6 +445,8 @@ int RunPortTracker(int Argc, char** Argv)
auto& ifRequest = *reinterpret_cast(ifreqMemory->data());
memcpy(request.InterfaceName, ifRequest.ifr_ifrn.ifrn_name, sizeof(request.InterfaceName));
request.InterfaceUp = ifRequest.ifr_ifru.ifru_flags & IFF_UP;
+
+ std::lock_guard lock(*channelMutex);
const auto& reply = hvSocketChannel->Transaction(request);
return reply.Result;
diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp
index f31ebf732..c5c8f4d9a 100644
--- a/src/linux/init/main.cpp
+++ b/src/linux/init/main.cpp
@@ -3540,7 +3540,7 @@ wil::unique_fd RegisterSeccompHook()
Routine Description:
- Register a seccomp notification for bind() & ioctl(*, TUNSETIFF, *) calls.
+ Register a seccomp notification for bind() & ioctl(*, SIOCSIFFLAGS, *) calls.
Arguments:
@@ -3568,7 +3568,7 @@ Return Value:
// If syscall_nr == __NR_bind then goto user_notify: else continue
BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr),
BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_bind, 3, 0),
- // if (syscall_nr == __NR_bind) then continue else goto allow:
+ // if (syscall_nr == __NR_ioctl) then continue else goto allow:
BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_ioctl, 0, 3),
// if (syscall arg1 == SIOCSIFFLAGS) goto user_notify else goto allow:
BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(1)),