From 8e8ee5e26575c8169d05704aa69fadbcf1213f04 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Fri, 10 Apr 2026 17:30:11 +0800 Subject: [PATCH 1/7] wip --- CMakeLists.txt | 1 + src/linux/bpf/CMakeLists.txt | 124 +++++ src/linux/bpf/bind_monitor.bpf.c | 118 +++++ src/linux/bpf/bind_monitor.h | 36 ++ src/linux/bpf/generate-vmlinux-header.sh | 111 ++++ src/linux/init/CMakeLists.txt | 7 + src/linux/init/GnsPortTracker.cpp | 640 ++--------------------- src/linux/init/GnsPortTracker.h | 166 +----- src/linux/init/localhost.cpp | 26 +- src/linux/init/main.cpp | 51 +- 10 files changed, 458 insertions(+), 822 deletions(-) create mode 100644 src/linux/bpf/CMakeLists.txt create mode 100644 src/linux/bpf/bind_monitor.bpf.c create mode 100644 src/linux/bpf/bind_monitor.h create mode 100644 src/linux/bpf/generate-vmlinux-header.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 2834728e3..b5e8c7c25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ find_nuget_package(Microsoft.WSL.DeviceHost WSL_DEVICE_HOST /build/native) find_nuget_package(Microsoft.WSL.Kernel KERNEL /build/native) find_nuget_package(Microsoft.WSL.bsdtar BSDTARD /build/native/bin) find_nuget_package(Microsoft.WSL.LinuxSdk LINUXSDK /) +find_nuget_package(Microsoft.WSL.BPF WSL_BPF /) find_nuget_package(Microsoft.WSL.TestDistro TEST_DISTRO /) find_nuget_package(Microsoft.WSLg WSLG /build/native/bin) find_nuget_package(vswhere VSWHERE /tools) diff --git a/src/linux/bpf/CMakeLists.txt b/src/linux/bpf/CMakeLists.txt new file mode 100644 index 000000000..e069def24 --- /dev/null +++ b/src/linux/bpf/CMakeLists.txt @@ -0,0 +1,124 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# +# CMakeLists.txt for building BPF skeleton headers. +# This is intended to be run on a Linux build machine, NOT as part of the main Windows build. 
+#
+# Required parameters:
+#   -DKERNEL_IMAGE_PATH=<path>  Path to the WSL kernel binary (bzImage/Image) with embedded BTF
+#   -DBPFTOOL_PATH=<path>       Path to the bpftool binary
+#
+# Optional parameters:
+#   -DCLANG_BPF=<path>          Path to clang (default: clang)
+#   -DOUTPUT_DIR=<path>         Output directory for generated headers (default: ${CMAKE_BINARY_DIR}/output)
+#
+# Usage:
+#   cmake -S . -B build -DKERNEL_IMAGE_PATH=/path/to/bzImage -DBPFTOOL_PATH=/usr/sbin/bpftool
+#   cmake --build build
+#
+# Output:
+#   ${OUTPUT_DIR}/bind_monitor.skel.h  - BPF skeleton header (embed in NuGet)
+#   ${OUTPUT_DIR}/bind_monitor.h       - Shared event struct header (embed in NuGet)
+
+cmake_minimum_required(VERSION 3.20)
+project(wsl-bpf LANGUAGES NONE)
+
+# Validate required parameters
+if(NOT DEFINED KERNEL_IMAGE_PATH)
+    message(FATAL_ERROR "KERNEL_IMAGE_PATH is required. Pass -DKERNEL_IMAGE_PATH=<path>")
+endif()
+
+if(NOT EXISTS "${KERNEL_IMAGE_PATH}")
+    message(FATAL_ERROR "Kernel image not found: ${KERNEL_IMAGE_PATH}")
+endif()
+
+if(NOT DEFINED BPFTOOL_PATH)
+    message(FATAL_ERROR "BPFTOOL_PATH is required. Pass -DBPFTOOL_PATH=<path>")
+endif()
+
+if(NOT EXISTS "${BPFTOOL_PATH}")
+    message(FATAL_ERROR "bpftool not found: ${BPFTOOL_PATH}")
+endif()
+
+# Optional parameters
+if(NOT DEFINED CLANG_BPF)
+    set(CLANG_BPF "clang")
+endif()
+
+if(NOT DEFINED OUTPUT_DIR)
+    set(OUTPUT_DIR "${CMAKE_BINARY_DIR}/output")
+endif()
+
+file(MAKE_DIRECTORY "${OUTPUT_DIR}")
+
+set(SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(VMLINUX_H "${CMAKE_BINARY_DIR}/vmlinux.h")
+set(BPF_OBJ "${CMAKE_BINARY_DIR}/bind_monitor.bpf.o")
+set(SKEL_H "${OUTPUT_DIR}/bind_monitor.skel.h")
+set(EVENT_H "${OUTPUT_DIR}/bind_monitor.h")
+
+# Detect target architecture for BPF
+execute_process(
+    COMMAND uname -m
+    OUTPUT_VARIABLE HOST_ARCH
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+if(HOST_ARCH STREQUAL "x86_64")
+    set(BPF_TARGET_ARCH "x86")
+elseif(HOST_ARCH STREQUAL "aarch64")
+    set(BPF_TARGET_ARCH "arm64")
+else()
+    message(FATAL_ERROR "Unsupported architecture: ${HOST_ARCH}")
+endif()
+
+# Locate libbpf headers (needed for BPF compilation)
+find_path(LIBBPF_INCLUDE_DIR "bpf/bpf_helpers.h"
+    PATHS /usr/include /usr/local/include
+    REQUIRED)
+
+# Step 1: Generate vmlinux.h from kernel image (the script takes its bpftool from $BPFTOOL, so pass BPFTOOL_PATH on)
+add_custom_command(
+    OUTPUT "${VMLINUX_H}"
+    COMMAND ${CMAKE_COMMAND} -E env "BPFTOOL=${BPFTOOL_PATH}" bash "${SRC_DIR}/generate-vmlinux-header.sh" "${KERNEL_IMAGE_PATH}" "${VMLINUX_H}"
+    DEPENDS "${KERNEL_IMAGE_PATH}" "${SRC_DIR}/generate-vmlinux-header.sh"
+    COMMENT "Generating vmlinux.h from kernel image"
+    VERBATIM)
+
+# Step 2: Compile BPF program
+add_custom_command(
+    OUTPUT "${BPF_OBJ}"
+    COMMAND "${CLANG_BPF}" -O2 -g -target bpf
+        -D__TARGET_ARCH_${BPF_TARGET_ARCH}
+        -I "${CMAKE_BINARY_DIR}"
+        -I "${LIBBPF_INCLUDE_DIR}"
+        -I "${SRC_DIR}"
+        -c "${SRC_DIR}/bind_monitor.bpf.c"
+        -o "${BPF_OBJ}"
+    DEPENDS "${SRC_DIR}/bind_monitor.bpf.c" "${SRC_DIR}/bind_monitor.h" "${VMLINUX_H}"
+    COMMENT "Compiling bind_monitor.bpf.c"
+    VERBATIM)
+
+# Step 3: Generate skeleton header
+add_custom_command(
+    OUTPUT "${SKEL_H}"
+    COMMAND "${BPFTOOL_PATH}" gen skeleton "${BPF_OBJ}" > "${SKEL_H}"
+    DEPENDS "${BPF_OBJ}"
+    COMMENT "Generating bind_monitor.skel.h"
+    VERBATIM)
+
+# Step 4: Copy shared header to output
+add_custom_command(
+    OUTPUT "${EVENT_H}"
+    COMMAND ${CMAKE_COMMAND} -E copy "${SRC_DIR}/bind_monitor.h" "${EVENT_H}"
+    DEPENDS "${SRC_DIR}/bind_monitor.h"
+    COMMENT "Copying bind_monitor.h to output"
+    VERBATIM)
+
+# Main target
+add_custom_target(bpf_headers ALL DEPENDS "${SKEL_H}" "${EVENT_H}")
+
+message(STATUS "BPF build configuration:")
+message(STATUS "  Kernel image: ${KERNEL_IMAGE_PATH}")
+message(STATUS "  bpftool:      ${BPFTOOL_PATH}")
+message(STATUS "  clang:        ${CLANG_BPF}")
+message(STATUS "  Target arch:  ${BPF_TARGET_ARCH}")
+message(STATUS "  Output dir:   ${OUTPUT_DIR}")
diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c
new file mode 100644
index 000000000..7ede0c8e7
--- /dev/null
+++ b/src/linux/bpf/bind_monitor.bpf.c
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include "vmlinux.h"
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_tracing.h>
+#include <bpf/bpf_core_read.h>
+#include <bpf/bpf_endian.h>
+#include "bind_monitor.h"
+
+// Ring buffer that carries bind_event records to user space.
+struct {
+    __uint(type, BPF_MAP_TYPE_RINGBUF);
+    __uint(max_entries, BIND_MONITOR_RINGBUF_SIZE);
+} events SEC(".maps");
+
+// Track sockets that went through bind() so we only emit matching releases.
+struct {
+    __uint(type, BPF_MAP_TYPE_HASH);
+    __type(key, __u64); // sock pointer as key
+    __type(value, __u8);
+    __uint(max_entries, 65536);
+} bound_sockets SEC(".maps");
+
+// Reports a bind or a release of sk to user space through the events ring buffer.
+//
+//   sk      - socket being bound or released; NULL is ignored.
+//   is_bind - 1 when called from a bind hook, 0 when called from the release hook.
+//   ret     - return value of the hooked bind function; unused for releases.
+//
+// Only AF_INET/AF_INET6 sockets using TCP or UDP are reported, and a release is only
+// reported for a socket that is recorded in bound_sockets. Always returns 0.
+static __always_inline int emit_event(struct sock *sk, __u8 is_bind, int ret)
+{
+    struct bind_event *e;
+    __u16 family;
+    __u16 protocol;
+    __u64 sk_key = (__u64)sk;
+    __u8 val = 1;
+
+    if (!sk)
+        return 0;
+
+    family = BPF_CORE_READ(sk, __sk_common.skc_family);
+    if (family != AF_INET && family != AF_INET6)
+        return 0;
+
+    protocol = BPF_CORE_READ(sk, sk_protocol);
+    if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP)
+        return 0;
+
+    if (is_bind)
+    {
+        // Only emit successful binds.
+        if (ret != 0)
+            return 0;
+    }
+    else
+    {
+        // Only emit release if this socket was previously bound.
+        if (!bpf_map_lookup_elem(&bound_sockets, &sk_key))
+            return 0;
+        bpf_map_delete_elem(&bound_sockets, &sk_key);
+    }
+
+    e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
+    if (!e)
+        return 0;
+
+    // Record the socket only once the ring buffer slot is secured, so that a dropped bind
+    // event is never followed by a release event for the same socket.
+    if (is_bind)
+        bpf_map_update_elem(&bound_sockets, &sk_key, &val, BPF_ANY);
+
+    e->family = family;
+    e->protocol = protocol;
+    // skc_num already holds the local port in host byte order (only skc_dport is
+    // big-endian), so it is stored unconverted.
+    e->port = BPF_CORE_READ(sk, __sk_common.skc_num);
+    e->is_bind = is_bind;
+    e->pad = 0;
+
+    if (family == AF_INET)
+    {
+        struct inet_sock *inet = (struct inet_sock *)sk;
+        e->addr4 = BPF_CORE_READ(inet, inet_saddr);
+        __builtin_memset(e->addr6, 0, sizeof(e->addr6));
+    }
+    else
+    {
+        struct ipv6_pinfo *pinet6;
+        e->addr4 = 0;
+        pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6);
+        // The destination is passed as a pointer to the whole array because
+        // BPF_CORE_READ_INTO copies sizeof(*dst) bytes.
+        if (pinet6)
+            BPF_CORE_READ_INTO(&e->addr6, pinet6, saddr.in6_u.u6_addr8);
+        else
+            __builtin_memset(e->addr6, 0, sizeof(e->addr6));
+    }
+
+    bpf_ringbuf_submit(e, 0);
+    return 0;
+}
+
+// fexit/inet_bind: called after inet_bind() returns.
+// int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
+SEC("fexit/inet_bind")
+int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr,
+             int addr_len, int ret)
+{
+    struct sock *sk = BPF_CORE_READ(sock, sk);
+    return emit_event(sk, 1, ret);
+}
+
+// fexit/inet6_bind: called after inet6_bind() returns.
+// int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
+SEC("fexit/inet6_bind")
+int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr,
+             int addr_len, int ret)
+{
+    struct sock *sk = BPF_CORE_READ(sock, sk);
+    return emit_event(sk, 1, ret);
+}
+
+// fentry/inet_release: called when a socket is being closed.
+// int inet_release(struct socket *sock)
+SEC("fentry/inet_release")
+int BPF_PROG(fentry_inet_release, struct socket *sock)
+{
+    struct sock *sk = BPF_CORE_READ(sock, sk);
+    return emit_event(sk, 0, 0);
+}
+
+char LICENSE[] SEC("license") = "GPL";
diff --git a/src/linux/bpf/bind_monitor.h b/src/linux/bpf/bind_monitor.h
new file mode 100644
index 000000000..9f093bda0
--- /dev/null
+++ b/src/linux/bpf/bind_monitor.h
@@ -0,0 +1,45 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+// Definitions shared between bind_monitor.bpf.c and the user-space code that
+// consumes its events.
+
+#pragma once
+
+// NOTE(review): __u8 is a typedef, not a macro, so this guard is always taken. Where
+// the kernel types are already declared (e.g. after vmlinux.h) the typedefs below are
+// redeclarations and must stay identical to them - confirm this is intended.
+#ifndef __u8
+typedef unsigned char __u8;
+typedef unsigned short __u16;
+typedef unsigned int __u32;
+typedef unsigned long long __u64;
+#endif
+
+// Socket constants, defined here only when no earlier header provided them as macros.
+#ifndef AF_INET
+#define AF_INET 2
+#endif
+#ifndef AF_INET6
+#define AF_INET6 10
+#endif
+
+#ifndef IPPROTO_TCP
+#define IPPROTO_TCP 6
+#endif
+#ifndef IPPROTO_UDP
+#define IPPROTO_UDP 17
+#endif
+
+// Size in bytes of the ring buffer map that carries bind_event records.
+#define BIND_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB */
+
+// One record per reported bind or release.
+struct bind_event {
+    __u32 family;    /* AF_INET or AF_INET6 */
+    __u32 protocol;  /* IPPROTO_TCP or IPPROTO_UDP */
+    __u16 port;      /* host byte order */
+    __u8 is_bind;    /* 1 = bind, 0 = release */
+    __u8 pad;
+    __u32 addr4;     /* IPv4 address (network byte order) */
+    __u8 addr6[16];  /* IPv6 address */
+};
diff --git a/src/linux/bpf/generate-vmlinux-header.sh b/src/linux/bpf/generate-vmlinux-header.sh
new file mode 100644
index 000000000..5c7a73a74
--- /dev/null
+++ b/src/linux/bpf/generate-vmlinux-header.sh
@@ -0,0 +1,138 @@
+#!/bin/bash
+#
+# Generates a vmlinux.h header from the BTF data of a kernel image.
+#
+# Usage: generate-vmlinux-header.sh <kernel-image> <output-header>
+#
+# Environment:
+#   BPFTOOL   bpftool binary to run (default: bpftool found on PATH)
+#
+# The following BTF sources are tried in order:
+#   1. the image itself;
+#   2. an ELF that binwalk finds inside the image;
+#   3. a raw BTF blob located by its magic number;
+#   4. the gzip payload that binwalk finds, decompressed and, if it is not an ELF
+#      itself, trimmed to the first ELF magic inside it.
+# Sources 1-3 are accepted only if the generated header has more than 100 lines.
+#
+# <output-header> is written only after a header was generated successfully, so a
+# failed attempt never leaves a truncated file behind.
+set -e
+
+if [ $# -lt 2 ]; then
+    echo "Usage: $0 <kernel-image> <output-header>" >&2
+    exit 1
+fi
+
+BZIMAGE="$1"
+OUTPUT="$2"
+BPFTOOL="${BPFTOOL:-bpftool}"
+
+# Deliberately not called TMPDIR, which mktemp and other tools read themselves.
+WORK_DIR=$(mktemp -d)
+trap 'rm -rf "$WORK_DIR"' EXIT
+
+VMLINUX="$WORK_DIR/vmlinux"
+HEADER_TMP="$WORK_DIR/vmlinux.h"
+
+# is_elf <file>: succeeds if <file> exists and file(1) identifies it as ELF.
+# -b keeps the file name out of the text being matched.
+is_elf() {
+    [ -f "$1" ] && file -b "$1" | grep -q 'ELF'
+}
+
+# install_header: moves the generated header to $OUTPUT, reports it and exits.
+install_header() {
+    local line_count
+    line_count=$(wc -l < "$HEADER_TMP")
+    mv "$HEADER_TMP" "$OUTPUT"
+    echo "Done. Generated $OUTPUT ($line_count lines)"
+    exit 0
+}
+
+# try_dump <file>: dumps the BTF of <file> as C. Ends the script successfully if
+# the result has more than 100 lines; otherwise returns so the next source is tried.
+try_dump() {
+    if "$BPFTOOL" btf dump file "$1" format c > "$HEADER_TMP" 2>/dev/null &&
+        [ "$(wc -l < "$HEADER_TMP")" -gt 100 ]; then
+        install_header
+    fi
+}
+
+echo "Extracting vmlinux from $BZIMAGE..."
+
+# First, try bpftool directly (works if input is already an ELF with BTF)
+try_dump "$BZIMAGE"
+
+# Try to find an ELF embedded in the image (e.g., ARM64 Image)
+ELF_OFFSET=$(binwalk -y elf "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true
+if [ -n "$ELF_OFFSET" ]; then
+    echo "Found ELF at offset $ELF_OFFSET"
+    tail -c +$((ELF_OFFSET + 1)) "$BZIMAGE" > "$VMLINUX"
+    if is_elf "$VMLINUX"; then
+        try_dump "$VMLINUX"
+    fi
+fi
+
+# Try to find raw BTF data in the image (ARM64 Image stores BTF as raw data).
+# A candidate starts with the BTF magic (bytes 9f eb) and version 1; its size is the
+# header length plus the end offset of the string section.
+BTF_RAW="$WORK_DIR/btf.raw"
+BTF_LOCATION=$(python3 -c "
+import struct, sys
+data = open(sys.argv[1], 'rb').read()
+magic = b'\x9f\xeb'
+idx = 0
+while True:
+    idx = data.find(magic, idx)
+    if idx == -1 or idx + 24 > len(data): break
+    if data[idx+2] == 1:  # version 1
+        hdr_len = struct.unpack_from('<I', data, idx+4)[0]
+        str_off = struct.unpack_from('<I', data, idx+16)[0]
+        str_len = struct.unpack_from('<I', data, idx+20)[0]
+        total = hdr_len + str_off + str_len
+        if total > 1000:
+            print(f'{idx} {total}')
+            break
+    idx += 1
+" "$BZIMAGE" 2>/dev/null) || true
+
+if [ -n "$BTF_LOCATION" ]; then
+    read -r BTF_OFF BTF_SIZE <<< "$BTF_LOCATION"
+    echo "Found raw BTF at offset $BTF_OFF ($BTF_SIZE bytes)"
+    tail -c +$((BTF_OFF + 1)) "$BZIMAGE" | head -c "$BTF_SIZE" > "$BTF_RAW"
+    try_dump "$BTF_RAW"
+fi
+
+# Find gzip offset and decompress to get the vmlinux ELF
+GZIP_OFFSET=$(binwalk -y gzip "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true
+
+if [ -z "$GZIP_OFFSET" ]; then
+    echo "Error: no gzip or ELF payload found in $BZIMAGE" >&2
+    exit 1
+fi
+
+echo "Found gzip payload at offset $GZIP_OFFSET"
+tail -c +$((GZIP_OFFSET + 1)) "$BZIMAGE" | zcat > "$VMLINUX" 2>/dev/null || true
+
+if ! is_elf "$VMLINUX"; then
+    # The gzip payload might contain another layer; try to find ELF inside
+    INNER="$WORK_DIR/inner"
+    mv "$VMLINUX" "$INNER"
+
+    # Search for ELF magic in decompressed data
+    ELF_OFFSET=$(grep -a -b -o -P '\x7fELF' "$INNER" 2>/dev/null | head -1 | cut -d: -f1) || true
+    if [ -n "$ELF_OFFSET" ]; then
+        tail -c +$((ELF_OFFSET + 1)) "$INNER" > "$VMLINUX"
+    fi
+fi
+
+if ! is_elf "$VMLINUX"; then
+    echo "Error: could not extract a valid ELF vmlinux from $BZIMAGE" >&2
+    exit 1
+fi
+
+echo "Extracted vmlinux: $(file "$VMLINUX")"
+echo "Generating vmlinux.h with bpftool..."
+"$BPFTOOL" btf dump file "$VMLINUX" format c > "$HEADER_TMP"
+install_header
diff --git a/src/linux/init/CMakeLists.txt b/src/linux/init/CMakeLists.txt
index d5dee8005..de419af55 100644
--- a/src/linux/init/CMakeLists.txt
+++ b/src/linux/init/CMakeLists.txt
@@ -46,7 +46,14 @@ set(HEADERS
     wslpath.h)
 
 set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${CMAKE_CURRENT_LIST_DIR}/../netlinkutil")
+set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/bind_monitor/${TARGET_PLATFORM}/include")
+set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/include")
 set(INIT_LIBRARIES ${COMMON_LINUX_LINK_LIBRARIES} netlinkutil plan9 mountutil configfile)
+set(INIT_EXTRA_LIBS
+    ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libbpf.a
+    ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libelf.a
+    ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libz.a)
+set(LINUX_LDFLAGS ${LINUX_LDFLAGS} ${INIT_EXTRA_LIBS})
 add_linux_executable(init "${SOURCES}" "${HEADERS};${COMMON_LINUX_HEADERS}" "${INIT_LIBRARIES}")
 add_dependencies(init localization)
 
diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp
index 669582a14..92cf1678b 100644
--- a/src/linux/init/GnsPortTracker.cpp
+++ 
b/src/linux/init/GnsPortTracker.cpp @@ -1,639 +1,95 @@ // Copyright (C) Microsoft Corporation. All rights reserved. -#include -#include #include -#include /* Definition of AUDIT_* constants */ -#include -#include -#include -#include -#include -#include "common.h" // Needs to be included before sal.h before of __reserved macro -#include "NetlinkTransactionError.h" +#include +#include +#include "common.h" #include "GnsPortTracker.h" #include "lxinitshared.h" +#include "bind_monitor.skel.h" -constexpr size_t c_bind_timeout_seconds = 60; -constexpr auto c_sock_diag_refresh_delay = std::chrono::milliseconds(500); -constexpr auto c_sock_diag_poll_timeout = std::chrono::milliseconds(10); -constexpr auto c_bpf_poll_timeout = std::chrono::milliseconds(500); +namespace { -GnsPortTracker::GnsPortTracker( - std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher) : - m_hvSocketChannel(std::move(hvSocketChannel)), m_channel(std::move(netlinkChannel)), m_seccompDispatcher(seccompDispatcher) +extern "C" int OnBindMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept { - m_networkNamespace = std::filesystem::read_symlink("/proc/self/ns/net").string(); -} - -void GnsPortTracker::RunPortRefresh() -{ - UtilSetThreadName("GnsPortTracker"); - - // The polling of bound sockets is done in a separate thread because - // sock_diag sometimes fails with EBUSY when a bind() is in progress. - // Doing this in a separate thread allows the main thread not to be delayed - // because of transient sock_diag failures - - for (;;) - { - // Netlink will sometimes return EBUSY. 
Don't fail for that - try - { - std::promise resume; - auto result = PortRefreshResult{ListAllocatedPorts(), time(nullptr), std::bind(&std::promise::set_value, &resume)}; - m_allocatedPortsRefresh.set_value(result); - - resume.get_future().wait(); - } - catch (const NetlinkTransactionError& e) - { - if (e.Error().value_or(0) != -EBUSY) - { - std::cerr << "Failed to refresh allocated ports, " << e.what() << std::endl; - } - } - - std::this_thread::sleep_for(c_sock_diag_refresh_delay); - } -} - -int GnsPortTracker::ProcessSecCompNotification(seccomp_notif* notification) -{ - seccomp_notif notificationCopy = *notification; - m_request.post(notificationCopy); - return m_reply.get(); -} - -void GnsPortTracker::Run() -{ - // This method consumes seccomp notifications and allows / disallows port allocations - // depending on wsl core's response. - // After dealing with a notification it also looks at the bound ports list to check - // for port deallocation - - std::thread{std::bind(&GnsPortTracker::RunPortRefresh, this)}.detach(); - std::thread{std::bind(&GnsPortTracker::RunDeferredResolve, this)}.detach(); - - auto future = std::make_optional(m_allocatedPortsRefresh.get_future()); - std::optional refreshResult; - - for (;;) - { - std::optional bindCall; - try - { - bindCall = ReadNextRequest(); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to read bind request, {}", e.what()); - } - - if (bindCall.has_value()) - { - int result = 0; - if (bindCall->Request.has_value()) - { - PortAllocation& allocationRequest = bindCall->Request.value(); - result = HandleRequest(allocationRequest); - if (result == 0) - { - TrackPort(allocationRequest); - GNS_LOG_INFO( - "Tracking bind call: family ({}) port ({}) protocol ({})", - allocationRequest.Family, - allocationRequest.Port, - allocationRequest.Protocol); - } - } - - try - { - CompleteRequest(bindCall->CallId, result); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to complete bind request, {}", 
e.what()); - } - - if (bindCall->PortZeroBind.has_value()) - { - try - { - std::lock_guard lock(m_deferredMutex); - m_deferredQueue.push_back(std::move(bindCall->PortZeroBind.value())); - m_deferredCv.notify_one(); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to queue port-0 bind for deferred resolution, {}", e.what()); - } - } - } - - // If bindCall is empty, then the read() timed out. Look for any closed port - if (future.has_value() && future->wait_for(c_sock_diag_poll_timeout) == std::future_status::ready) - { - refreshResult.emplace(future->get()); - future.reset(); - m_allocatedPortsRefresh = {}; - - // If this loop's iteration had a bind call, it's possible that RefreshAllocatedPort - // was called before the bind called was processed. Make sure that the port list - // is up to date (If this is called, the next block will schedule another refresh) - if (!bindCall.has_value()) - { - OnRefreshAllocatedPorts(refreshResult->Ports, refreshResult->Timestamp); - } - } - - // Process any port-0 binds that the background thread has resolved. 
- std::deque resolved; - { - std::lock_guard lock(m_resolvedMutex); - resolved.swap(m_resolvedQueue); - } - for (auto& allocation : resolved) - { - const auto result = HandleRequest(allocation); - if (result == 0) - { - TrackPort(std::move(allocation)); - } - else - { - GNS_LOG_ERROR( - "Failed to register resolved port-0 bind: family ({}) port ({}) protocol ({}), error {}", - allocation.Family, - allocation.Port, - allocation.Protocol, - result); - } - } - - // Only look at bound ports if there's something to deallocate to avoid wasting cycles - if (refreshResult.has_value()) - { - if (!m_allocatedPorts.empty()) - { - future = m_allocatedPortsRefresh.get_future(); - refreshResult->Resume(); // This will resume the sock_diag thread - refreshResult.reset(); - } - } - } -} - -std::set GnsPortTracker::ListAllocatedPorts() -{ - std::set ports; + auto* tracker = static_cast(ctx); + const auto* event = static_cast(data); - inet_diag_req_v2 message{}; - message.sdiag_family = AF_INET; - message.sdiag_protocol = IPPROTO_TCP; - message.idiag_states = ~0; - - auto onMessage = [&](const NetlinkResponse& response) { - for (const auto& e : response.Messages(SOCK_DIAG_BY_FAMILY)) - { - const auto* payload = e.Payload(); - in6_addr address = {}; - - if (payload->idiag_family == AF_INET6) - { - static_assert(sizeof(address.s6_addr32) == 16); - static_assert(sizeof(address.s6_addr32) == sizeof(payload->id.idiag_src)); - memcpy(address.s6_addr32, payload->id.idiag_src, sizeof(address.s6_addr32)); - } - else - { - address.s6_addr32[0] = payload->id.idiag_src[0]; - } - - ports.emplace(ntohs(payload->id.idiag_sport), static_cast(payload->idiag_family), static_cast(message.sdiag_protocol), address); - } - }; - - { - auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - - message.sdiag_family = AF_INET6; + try { - auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - 
transaction.Execute(onMessage); + tracker->RequestPort(*event); } - - message.sdiag_protocol = IPPROTO_UDP; - message.sdiag_family = AF_INET; + catch (...) { - auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); + LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); } - message.sdiag_family = AF_INET6; - { - auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - - return ports; + return 0; } -void GnsPortTracker::OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp) -{ - // Because there's no way to get notified when the bind() call actually completes, it' possible - // that this method is called before the bind() completion and so the port allocation may not be visible yet. - // To avoid deallocating ports that simply haven't been done allocating yet, m_allocatedPorts stores a timeout - // that prevents deallocating the port unless: - // - // - The port has been seen to be allocated (if so, then the timeout is empty) - // - The timeout has expired - - for (auto it = m_allocatedPorts.begin(); it != m_allocatedPorts.end();) - { - if (Ports.find(it->first) == Ports.end()) - { - if (!it->second.has_value() || it->second.value() < Timestamp) - { - auto result = RequestPort(it->first, false); - if (result != 0) - { - std::cerr << "GnsPortTracker: Failed to deallocate port " << it->first << ", " << result << std::endl; - } - - GNS_LOG_INFO( - "No longer tracking bind call: family ({}) port ({}) protocol ({})", - it->first.Family, - it->first.Port, - it->first.Protocol); +} // namespace - it = m_allocatedPorts.erase(it); - continue; - } - } - else - { - it->second.reset(); // The port is known to be allocated, remove the timeout - } - - it++; - } +GnsPortTracker::GnsPortTracker(std::shared_ptr hvSocketChannel) : + m_hvSocketChannel(std::move(hvSocketChannel)) +{ } -int GnsPortTracker::RequestPort(const 
PortAllocation& Port, bool allocate) +void GnsPortTracker::RequestPort(const bind_event& Event) { LX_GNS_PORT_ALLOCATION_REQUEST request{}; request.Header.MessageType = LxGnsMessagePortMappingRequest; request.Header.MessageSize = sizeof(request); - request.Af = Port.Family; - request.Protocol = Port.Protocol; - request.Port = Port.Port; - request.Allocate = allocate; - static_assert(sizeof(request.Address32) == 16); - static_assert(sizeof(request.Address32) == sizeof(Port.Address.s6_addr32)); - memcpy(request.Address32, Port.Address.s6_addr32, sizeof(request.Address32)); - - const auto& response = m_hvSocketChannel->Transaction(request); - - return response.Result; -} - -int GnsPortTracker::HandleRequest(const PortAllocation& Port) -{ - // If the port is already allocated, let the call go through and the kernel will - // decide if bind() should succeed or not - // Note: Returning 0 will also cause the port's timeout to be updated - - if (m_allocatedPorts.contains(Port)) - { - GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol); - return 0; - } - - // Ask the host for this port otherwise - const auto error = RequestPort(Port, true); - GNS_LOG_INFO( - "Requested the host for port allocation on port (family {}, port {}, protocol {}) - returned {}", Port.Family, Port.Port, Port.Protocol, error); - return error; -} - -std::optional GnsPortTracker::ReadNextRequest() -{ - // Read the call information - auto request_value = m_request.try_get(c_bpf_poll_timeout); - if (!request_value.has_value()) - { - return {}; - } - - auto callInfo = request_value.value(); - - // This logic needs to be defensive because the calling process is blocked until - // CompleteRequest() is called, so if the call information can't be processed because - // the caller has done something wrong (bad pointer, fd, or protocol), just let it go through - // and the kernel will fail it - - try - { - return 
GetCallInfo(callInfo.id, callInfo.pid, callInfo.data.arch, callInfo.data.nr, gsl::make_span(callInfo.data.args)); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to read bind() call info with ID {} for pid {}, {}", callInfo.id, callInfo.pid, e.what()); - return {{{}, {}, callInfo.id}}; - } -} - -std::optional GnsPortTracker::GetCallInfo( - uint64_t CallId, pid_t Pid, int Arch, int SysCallNumber, const gsl::span& Arguments) -{ - auto ParseSocket = [&](int Socket, size_t AddressPtr, size_t AddressLength) -> std::optional { - if (AddressLength < sizeof(sockaddr)) - { - return {{{}, {}, CallId}}; // Invalid sockaddr. Let it go through. - } - - auto networkNamespace = std::filesystem::read_symlink(std::format("/proc/{}/ns/net", Pid)).string(); - if (networkNamespace != m_networkNamespace) - { - GNS_LOG_INFO("Skipping bind() call for pid {} in network namespace {}", Pid, networkNamespace.c_str()); - return {{{}, {}, CallId}}; // Different network namespace. Let it go through. - } - - auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, AddressPtr, AddressLength); - if (!processMemory.has_value()) - { - throw RuntimeErrorWithSourceLocation("Failed to read process memory"); - } - - sockaddr& address = *reinterpret_cast(processMemory->data()); - - if ((address.sa_family != AF_INET && address.sa_family != AF_INET6) || - (address.sa_family == AF_INET6 && AddressLength < sizeof(sockaddr_in6))) - { - return {{{}, {}, CallId}}; // This is a non IP call, or invalid sockaddr_in6. Let it go through - } - - // Read the port. The port *happens* to be in the same spot in memory for both sockaddr_in - // and sockaddr_in6. To avoid a second memory read, we take advantage of this fact to fetch - // the port from the currently read memory, regardless of the address family. 
- static_assert(sizeof(sockaddr_in) <= sizeof(sockaddr)); - - const auto* inAddr = reinterpret_cast(&address); - in_port_t port = ntohs(inAddr->sin_port); - if (port == 0) - { - // Port 0 means the kernel will assign an ephemeral port. We can't know - // the port until after the bind() completes, so duplicate the socket fd - // now (while the process is still stopped by seccomp) and defer the - // getsockname() lookup to after CompleteRequest() unblocks it. - try - { - const int protocol = GetSocketProtocol(Pid, Socket); - auto dupFd = DuplicateSocketFd(Pid, Socket); - if (!dupFd) - { - return {{{}, {}, CallId}}; - } - if (!m_seccompDispatcher->ValidateCookie(CallId)) - { - return {{{}, {}, CallId}}; - } - return {{{}, DeferredPortLookup{Pid, std::move(dupFd), protocol}, CallId}}; - } - catch (const std::exception&) - { - return {{{}, {}, CallId}}; // Can't determine protocol, just let it through - } - } - - in6_addr storedAddress = {}; - - if (address.sa_family == AF_INET) - { - storedAddress.s6_addr32[0] = inAddr->sin_addr.s_addr; - } - else - { - const auto* inAddr6 = reinterpret_cast(&address); - memcpy(storedAddress.s6_addr32, inAddr6->sin6_addr.s6_addr32, sizeof(storedAddress.s6_addr32)); - } - - // It's possible that the calling process lied and passed a sockaddr that - // doesn't match the underlying socket family or a bad fd. 
If that's the case, - // then GetSocketProtocol() will throw - const int protocol = GetSocketProtocol(Pid, Socket); + request.Af = Event.family; + request.Protocol = Event.protocol; + request.Port = Event.port; + request.Allocate = Event.is_bind; - // As GetSocketProtocol interacts with /proc/, to avoid TOCTOU races we need to - // verify that call is still valid (call id is the same thing as cookie) - if (!m_seccompDispatcher->ValidateCookie(CallId)) - { - throw RuntimeErrorWithSourceLocation(std::format("Invalid call id {}", CallId)); - } - - return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, {}, CallId}}; - }; -#ifdef __x86_64__ - if (Arch & __AUDIT_ARCH_64BIT) + static_assert(sizeof(request.Address32) == 16); + if (Event.family == AF_INET) { - return ParseSocket(Arguments[0], Arguments[1], Arguments[2]); + request.Address32[0] = Event.addr4; } - // Note: 32bit on x86_64 uses the __NR_socketcall with the first argument - // set to SYS_BIND to make bind system call and the second argument is - // a pointer to a block of memory containing the original arguments. 
else { - if (Arguments[0] != SYS_BIND) - { - return {{{}, {}, CallId}}; // Not a bind call, just let the call go through - } - // Grab the first 3 parameters - auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, Arguments[1], sizeof(uint32_t) * 3); - if (!processMemory.has_value()) - { - throw RuntimeErrorWithSourceLocation("Failed to read process memory"); - } - - uint32_t* CopiedArguments = reinterpret_cast(processMemory->data()); - return ParseSocket(CopiedArguments[0], CopiedArguments[1], CopiedArguments[2]); + memcpy(request.Address32, Event.addr6, sizeof(request.Address32)); } -#else - return ParseSocket(Arguments[0], Arguments[1], Arguments[2]); -#endif -} - -void GnsPortTracker::CompleteRequest(uint64_t id, int result) -{ - m_reply.post(result); -} -int GnsPortTracker::GetSocketProtocol(int pid, int fd) -{ - const auto path = std::format("/proc/{}/fd/{}", pid, fd); - - // Because there's a race between the time where the buffer size is determined - // and the actual getxattr() call, retry until the buffer size is big enough - std::string protocol; - int result = -1; - do - { - int bufferSize = Syscall(getxattr, path.c_str(), "system.sockprotoname", nullptr, 0); - protocol.resize(std::max(0, bufferSize - 1)); - - result = getxattr(path.c_str(), "system.sockprotoname", protocol.data(), bufferSize); - } while (result < 0 && errno == ERANGE); - - if (result < 0) - { - throw RuntimeErrorWithSourceLocation(std::format("Failed to read protocol for socket: {}, {}", path, errno)); - } - - // In case the size of the attribute shrunk between the two getxattr calls - protocol.resize(std::max(0, result - 1)); - - if (protocol == "TCP" || protocol == "TCPv6") - { - return IPPROTO_TCP; - } - else if (protocol == "UDP" || protocol == "UDPv6") - { - return IPPROTO_UDP; - } - - throw RuntimeErrorWithSourceLocation(std::format("Unexpected IP socket protocol: {}", protocol)); -} - -wil::unique_fd GnsPortTracker::DuplicateSocketFd(pid_t Pid, int SocketFd) 
-{ - // Duplicate the socket fd from the target process into our address space. - // We cannot use open("/proc/pid/fd/N") for sockets because the symlink target - // (socket:[inode]) is not a valid filesystem path. Use pidfd_getfd() instead. - wil::unique_fd pidFd(static_cast(syscall(SYS_pidfd_open, Pid, 0u))); - if (!pidFd) - { - GNS_LOG_INFO("Port-0 bind: pidfd_open failed for pid {} (errno {})", Pid, errno); - return {}; - } - - wil::unique_fd dupFd(static_cast(syscall(SYS_pidfd_getfd, pidFd.get(), SocketFd, 0u))); - if (!dupFd) - { - GNS_LOG_INFO("Port-0 bind: pidfd_getfd failed for pid {} fd {} (errno {})", Pid, SocketFd, errno); - } - - return dupFd; -} + const auto& response = m_hvSocketChannel->Transaction(request); -void GnsPortTracker::TrackPort(PortAllocation allocation) -try -{ - // Use insert_or_assign so the deallocation timeout is refreshed if the same - // port key is already present (emplace would silently keep the old entry). - m_allocatedPorts.insert_or_assign(std::move(allocation), std::make_optional(time(nullptr) + c_bind_timeout_seconds)); -} -catch (const std::exception& e) -{ - GNS_LOG_ERROR("Failed to track port allocation, {}", e.what()); + GNS_LOG_INFO( + "Port {} request: family ({}) port ({}) protocol ({}) result ({})", + Event.is_bind ? 
"allocate" : "release", + Event.family, + Event.port, + Event.protocol, + response.Result); } -void GnsPortTracker::RunDeferredResolve() +void GnsPortTracker::Run() { - UtilSetThreadName("GnsPortZero"); + auto* skel = bind_monitor_bpf__open_and_load(); + THROW_LAST_ERROR_IF(!skel); - for (;;) - { - DeferredPortLookup lookup{0, {}, 0}; - { - std::unique_lock lock(m_deferredMutex); - m_deferredCv.wait(lock, [&] { return !m_deferredQueue.empty(); }); - lookup = std::move(m_deferredQueue.front()); - m_deferredQueue.pop_front(); - } + auto destroySkel = wil::scope_exit([&] { bind_monitor_bpf__destroy(skel); }); - const auto pid = lookup.Pid; - try - { - ResolvePortZeroBind(std::move(lookup)); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to resolve port-0 bind for pid {}, {}", pid, e.what()); - } - } -} + THROW_LAST_ERROR_IF(bind_monitor_bpf__attach(skel) != 0); -void GnsPortTracker::ResolvePortZeroBind(DeferredPortLookup lookup) -{ - // The socket fd was already duplicated (via pidfd_getfd) while the target process - // was stopped by seccomp, so it remains valid even if the process has closed or - // reused the original fd number. + auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), OnBindMonitorEvent, this, nullptr); + THROW_LAST_ERROR_IF(!rb); - // The bind() syscall is being completed asynchronously on the seccomp dispatcher - // thread after CompleteRequest() unblocks it. Poll getsockname() briefly until - // the kernel assigns a port. 
- constexpr int maxRetries = 25; - constexpr auto retryDelay = std::chrono::milliseconds(100); + auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); - in_port_t port = 0; - in6_addr address = {}; - int resolvedFamily = 0; + GNS_LOG_INFO("BPF bind monitor attached and running"); - for (int attempt = 0; attempt < maxRetries; ++attempt) + for (;;) { - if (attempt > 0) + int err = ring_buffer__poll(rb, -1 /* block until event */); + if (err == -EINTR) { - std::this_thread::sleep_for(retryDelay); + continue; } - sockaddr_storage storage{}; - socklen_t addrLen = sizeof(storage); - if (getsockname(lookup.DuplicatedSocketFd.get(), reinterpret_cast(&storage), &addrLen) != 0) - { - GNS_LOG_ERROR("Port-0 bind: getsockname failed for pid {} (errno {})", lookup.Pid, errno); - return; - } - - resolvedFamily = static_cast(storage.ss_family); - - if (storage.ss_family == AF_INET) - { - const auto* sin = reinterpret_cast(&storage); - port = ntohs(sin->sin_port); - address.s6_addr32[0] = sin->sin_addr.s_addr; - } - else if (storage.ss_family == AF_INET6) - { - const auto* sin6 = reinterpret_cast(&storage); - port = ntohs(sin6->sin6_port); - memcpy(address.s6_addr32, sin6->sin6_addr.s6_addr32, sizeof(address.s6_addr32)); - } - else - { - GNS_LOG_ERROR("Port-0 bind: unexpected address family ({}) for pid {}", resolvedFamily, lookup.Pid); - return; - } - - if (port != 0) - { - break; - } - } - - if (port == 0) - { - GNS_LOG_ERROR("Port-0 bind: kernel did not assign a port for pid {} after retries", lookup.Pid); - return; - } - - PortAllocation allocation(port, resolvedFamily, lookup.Protocol, address); - GNS_LOG_INFO( - "Port-0 bind resolved: family ({}) port ({}) protocol ({}) for pid {}", resolvedFamily, port, lookup.Protocol, lookup.Pid); - { - std::lock_guard lock(m_resolvedMutex); - m_resolvedQueue.push_back(std::move(allocation)); + THROW_LAST_ERROR_IF(err < 0); } } diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h index 
a3f2945a1..4da0516c8 100644 --- a/src/linux/init/GnsPortTracker.h +++ b/src/linux/init/GnsPortTracker.h @@ -1,177 +1,23 @@ // Copyright (C) Microsoft Corporation. All rights reserved. #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include "util.h" -#include -#include "waitablevalue.h" -#include "SecCompDispatcher.h" #include "SocketChannel.h" +#include "util.h" +#include "bind_monitor.h" class GnsPortTracker { public: - GnsPortTracker(std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher); + GnsPortTracker(std::shared_ptr hvSocketChannel); - GnsPortTracker(const GnsPortTracker&) = delete; - GnsPortTracker(GnsPortTracker&&) = delete; - GnsPortTracker& operator=(const GnsPortTracker&) = delete; - GnsPortTracker& operator=(GnsPortTracker&&) = delete; + NON_COPYABLE(GnsPortTracker); + NON_MOVABLE(GnsPortTracker); void Run(); - int ProcessSecCompNotification(seccomp_notif* notification); - - struct PortAllocation - { - in6_addr Address = {}; - std::uint16_t Port = {}; - int Family = {}; - int Protocol = {}; - - PortAllocation(PortAllocation&&) = default; - PortAllocation(const PortAllocation&) = default; - - PortAllocation& operator=(PortAllocation&&) = default; - PortAllocation& operator=(const PortAllocation&) = default; - - PortAllocation(std::uint16_t Port, int Family, int Protocol, in6_addr& Address) : - Port(Port), Family(Family), Protocol(Protocol) - { - memcpy(this->Address.s6_addr32, Address.s6_addr32, sizeof(this->Address.s6_addr32)); - } - - bool operator<(const PortAllocation& other) const - { - if (Port < other.Port) - { - return true; - } - else if (Port > other.Port) - { - return false; - } - - if (Family < other.Family) - { - return true; - } - else if (Family > other.Family) - { - return false; - } - - if (Protocol < other.Protocol) - { - return true; - } - else if (Protocol > other.Protocol) - { - return false; - } - - 
static_assert(sizeof(Address.s6_addr32) == 16); - if (int res = memcmp(Address.s6_addr32, other.Address.s6_addr32, sizeof(Address.s6_addr32)); res < 0) - { - return true; - } - else if (res > 0) - { - return false; - } - - return false; - } - }; - - struct DeferredPortLookup - { - pid_t Pid; - wil::unique_fd DuplicatedSocketFd; // Duplicated via pidfd_getfd while process was stopped - int Protocol; - - DeferredPortLookup(pid_t Pid, wil::unique_fd DuplicatedSocketFd, int Protocol) : - Pid(Pid), DuplicatedSocketFd(std::move(DuplicatedSocketFd)), Protocol(Protocol) - { - } - - DeferredPortLookup(DeferredPortLookup&&) = default; - DeferredPortLookup& operator=(DeferredPortLookup&&) = default; - DeferredPortLookup(const DeferredPortLookup&) = delete; - DeferredPortLookup& operator=(const DeferredPortLookup&) = delete; - }; - - struct BindCall - { - std::optional Request; - std::optional PortZeroBind; - std::uint64_t CallId; - }; - - struct PortRefreshResult - { - std::set Ports; - time_t Timestamp; - std::function Resume; - }; + void RequestPort(const bind_event& Event); private: - void OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp); - - void RunPortRefresh(); - - std::set ListAllocatedPorts(); - - std::optional ReadNextRequest(); - - std::optional GetCallInfo(uint64_t CallId, pid_t Pid, int Arch, int SysCallNumber, const gsl::span& Arguments); - - int RequestPort(const PortAllocation& Port, bool Allocate); - - int HandleRequest(const PortAllocation& Request); - - void CompleteRequest(uint64_t Id, int Result); - - static int GetSocketProtocol(int Pid, int Fd); - - static wil::unique_fd DuplicateSocketFd(pid_t Pid, int SocketFd); - - void ResolvePortZeroBind(DeferredPortLookup lookup); - - void RunDeferredResolve(); - - void TrackPort(PortAllocation allocation); - - std::map> m_allocatedPorts; std::shared_ptr m_hvSocketChannel; - NetlinkChannel m_channel; - std::promise m_allocatedPortsRefresh; - - WaitableValue m_request; - WaitableValue m_reply; - - 
std::shared_ptr m_seccompDispatcher; - - std::string m_networkNamespace; - - std::mutex m_deferredMutex; - std::condition_variable m_deferredCv; - std::deque m_deferredQueue; - - // Resolved port-0 allocations posted by the background RunDeferredResolve thread - // for the main Run() loop to process (keeps SocketChannel access single-threaded). - std::mutex m_resolvedMutex; - std::deque m_resolvedQueue; }; - -std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& portAllocation); diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index f1af32182..37f934faa 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -403,8 +403,6 @@ int RunPortTracker(int Argc, char** Argv) " fd" " [" INIT_BPF_FD_ARG " fd]" - " [" INIT_NETLINK_FD_ARG - " fd]" " [" INIT_PORT_TRACKER_LOCALHOST_RELAY " fd]\n"; // This is only supported on VM mode. @@ -418,13 +416,11 @@ int RunPortTracker(int Argc, char** Argv) int BpfFd = -1; int PortTrackerFd = -1; - int NetlinkSocketFd = -1; int GuestRelayFd = -1; ArgumentParser parser(Argc, Argv); parser.AddArgument(Integer{BpfFd}, INIT_BPF_FD_ARG); parser.AddArgument(Integer{PortTrackerFd}, INIT_PORT_TRACKER_FD_ARG); - parser.AddArgument(Integer{NetlinkSocketFd}, INIT_NETLINK_FD_ARG); parser.AddArgument(Integer{GuestRelayFd}, INIT_PORT_TRACKER_LOCALHOST_RELAY); try @@ -437,7 +433,7 @@ int RunPortTracker(int Argc, char** Argv) return 1; } - const bool synchronousMode = BpfFd != -1 && NetlinkSocketFd != -1; + const bool synchronousMode = BpfFd != -1; const bool localhostRelay = GuestRelayFd != -1; auto hvSocketChannel = std::make_shared(wil::unique_fd{PortTrackerFd}, "localhost"); @@ -458,29 +454,12 @@ int RunPortTracker(int Argc, char** Argv) if (!synchronousMode) { - std::cerr << "either both or none of --bpf-fd and --netlink-socket can be passed\n"; + std::cerr << "synchronous mode requires --bpf-fd\n"; return 1; } - auto channel = NetlinkChannel::FromFd(NetlinkSocketFd); - auto 
seccompDispatcher = std::make_shared(BpfFd); - GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher); - - seccompDispatcher->RegisterHandler( - __NR_bind, [&portTracker](seccomp_notif* notification) { return portTracker.ProcessSecCompNotification(notification); }); - -#ifdef __x86_64__ - seccompDispatcher->RegisterHandler(I386_NR_socketcall, [&portTracker](seccomp_notif* notification) { - return portTracker.ProcessSecCompNotification(notification); - }); -#else - seccompDispatcher->RegisterHandler(ARMV7_NR_bind, [&portTracker](seccomp_notif* notification) { - return portTracker.ProcessSecCompNotification(notification); - }); -#endif - seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher](auto notification) -> int { LX_GNS_TUN_BRIDGE_REQUEST request{}; request.Header.MessageType = LxGnsMessageIfStateChangeRequest; @@ -500,6 +479,7 @@ int RunPortTracker(int Argc, char** Argv) return reply.Result; }); + GnsPortTracker portTracker(hvSocketChannel); try { portTracker.Run(); diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp index f31ebf732..da3d54812 100644 --- a/src/linux/init/main.cpp +++ b/src/linux/init/main.cpp @@ -1352,25 +1352,12 @@ Return Value: return; } - wil::unique_fd NetlinkSocket{}; wil::unique_fd BpfFd{}; wil::unique_fd GuestRelayFd{}; switch (Type) { case LxMiniInitPortTrackerTypeMirrored: { - - // - // Create a netlink socket before registering the bpf filter so creation of the socket - // does not trigger the filter. 
- // - - NetlinkSocket = CreateNetlinkSocket(); - if (!NetlinkSocket) - { - return; - } - BpfFd = RegisterSeccompHook(); if (!BpfFd) { @@ -1398,7 +1385,6 @@ Return Value: UtilCreateChildProcess( "PortTracker", [PortTrackerFd = std::move(PortTrackerFd), - NetlinkSocket = std::move(NetlinkSocket), BpfFd = std::move(BpfFd), GuestRelayFd = std::move(GuestRelayFd)]() { execl( @@ -1408,8 +1394,6 @@ Return Value: std::format("{}", PortTrackerFd.get()).c_str(), INIT_BPF_FD_ARG, std::format("{}", BpfFd.get()).c_str(), - INIT_NETLINK_FD_ARG, - std::format("{}", NetlinkSocket.get()).c_str(), INIT_PORT_TRACKER_LOCALHOST_RELAY, std::format("{}", GuestRelayFd.get()).c_str(), NULL); @@ -3540,7 +3524,7 @@ wil::unique_fd RegisterSeccompHook() Routine Description: - Register a seccomp notification for bind() & ioctl(*, TUNSETIFF, *) calls. + Register a seccomp notification for ioctl(*, TUNSETIFF, *) calls. Arguments: @@ -3563,12 +3547,9 @@ Return Value: // 64bit: // If syscall_arch & __AUDIT_ARCH_64BIT then continue else goto :32bit BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arch), - // For now, notify on all non-native arch - BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, __AUDIT_ARCH_64BIT, 0, 7), - // If syscall_nr == __NR_bind then goto user_notify: else continue + BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, __AUDIT_ARCH_64BIT, 0, 5), + // If syscall_nr == __NR_ioctl then continue else goto allow: BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), - BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_bind, 3, 0), - // if (syscall_nr == __NR_bind) then continue else goto allow: BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_ioctl, 0, 3), // if (syscall arg1 == SIOCSIFFLAGS) goto user_notify else goto allow: BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(1)), @@ -3580,34 +3561,10 @@ Return Value: // return SECCOMP_RET_ALLOW; BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), - // Note: 32bit on x86_64 uses the __NR_socketcall with the first argument - // set to SYS_BIND to make bind system call. 
-#ifdef __x86_64__ - // 32bit: - // If syscall_nr == __NR_socketcall then continue else goto allow: - BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), - BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, I386_NR_socketcall, 0, 3), - // if syscall arg0 == SYS_BIND then goto user_notify: else goto allow: - BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(0)), - BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, SYS_BIND, 0, 1), - // user_notify: - // return SECCOMP_RET_USER_NOTIF; - BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_USER_NOTIF), - // allow: - // return SECCOMP_RET_ALLOW; - BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), -#else - // 32bit: - // If syscall_nr == __NR_bind then goto user_notify: else goto allow: - BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), - BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, ARMV7_NR_bind, 0, 1), - // user_notify: - // return SECCOMP_RET_USER_NOTIF; - BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_USER_NOTIF), + // 32bit: no ioctl interception needed for 32-bit processes. // allow: // return SECCOMP_RET_ALLOW; BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), -#endif }; struct sock_fprog Prog = { From ccae9c1196fdfc587674ff4d73a363472a7ad1b9 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Fri, 10 Apr 2026 18:43:59 +0800 Subject: [PATCH 2/7] fix build --- CMakeLists.txt | 3 ++- packages.config | 1 + src/linux/init/GnsPortTracker.cpp | 5 ----- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b5e8c7c25..b840aa577 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,7 +74,8 @@ find_nuget_package(Microsoft.WSL.DeviceHost WSL_DEVICE_HOST /build/native) find_nuget_package(Microsoft.WSL.Kernel KERNEL /build/native) find_nuget_package(Microsoft.WSL.bsdtar BSDTARD /build/native/bin) find_nuget_package(Microsoft.WSL.LinuxSdk LINUXSDK /) -find_nuget_package(Microsoft.WSL.BPF WSL_BPF /) +#find_nuget_package(Microsoft.WSL.BPF WSL_BPF /) +set(WSL_BPF_SOURCE_DIR "${CMAKE_SOURCE_DIR}/packages/Microsoft.WSL.BPF.1.0.0") 
find_nuget_package(Microsoft.WSL.TestDistro TEST_DISTRO /) find_nuget_package(Microsoft.WSLg WSLG /build/native/bin) find_nuget_package(vswhere VSWHERE /tools) diff --git a/packages.config b/packages.config index 22cc87bb7..a71b31702 100644 --- a/packages.config +++ b/packages.config @@ -16,6 +16,7 @@ + diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp index 92cf1678b..6c0a4f5d7 100644 --- a/src/linux/init/GnsPortTracker.cpp +++ b/src/linux/init/GnsPortTracker.cpp @@ -92,8 +92,3 @@ void GnsPortTracker::Run() THROW_LAST_ERROR_IF(err < 0); } } - -std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& entry) -{ - return out << "Port=" << entry.Port << ", Family=" << entry.Family << ", Protocol=" << entry.Protocol; -} From 999663f65dcb7a300e9bda73426a432f7eaac241 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Tue, 14 Apr 2026 14:40:15 +0800 Subject: [PATCH 3/7] fix port endian in bpf --- src/linux/bpf/bind_monitor.bpf.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c index 7ede0c8e7..e72deda02 100644 --- a/src/linux/bpf/bind_monitor.bpf.c +++ b/src/linux/bpf/bind_monitor.bpf.c @@ -61,7 +61,7 @@ static __always_inline int emit_event(struct sock *sk, __u8 is_bind, int ret) e->family = family; e->protocol = protocol; - e->port = bpf_ntohs(BPF_CORE_READ(sk, __sk_common.skc_num)); + e->port = BPF_CORE_READ(sk, __sk_common.skc_num); e->is_bind = is_bind; e->pad = 0; From 018110cf1f5fb3037a0f701c3775f7bb546eefda Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Tue, 14 Apr 2026 14:55:41 +0800 Subject: [PATCH 4/7] cache socket information at bind --- src/linux/bpf/bind_monitor.bpf.c | 93 +++++++++++++++++--------------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c index e72deda02..478137fee 100644 --- a/src/linux/bpf/bind_monitor.bpf.c +++ 
b/src/linux/bpf/bind_monitor.bpf.c @@ -12,76 +12,81 @@ struct { __uint(max_entries, BIND_MONITOR_RINGBUF_SIZE); } events SEC(".maps"); -// Track sockets that went through bind() so we only emit matching releases. struct { __uint(type, BPF_MAP_TYPE_HASH); - __type(key, __u64); // sock pointer as key - __type(value, __u8); + __type(key, __u64); // sock pointer as key + __type(value, struct bind_event); // snapshot from bind time __uint(max_entries, 65536); } bound_sockets SEC(".maps"); -static __always_inline int emit_event(struct sock *sk, __u8 is_bind, int ret) +static __always_inline int emit_bind_event(struct sock *sk, int ret) { struct bind_event *e; - __u16 family; - __u16 protocol; + struct bind_event info = {}; __u64 sk_key = (__u64)sk; - if (!sk) + if (!sk || ret != 0) return 0; - family = BPF_CORE_READ(sk, __sk_common.skc_family); - if (family != AF_INET && family != AF_INET6) + info.family = BPF_CORE_READ(sk, __sk_common.skc_family); + if (info.family != AF_INET && info.family != AF_INET6) return 0; - protocol = BPF_CORE_READ(sk, sk_protocol); - if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) + info.protocol = BPF_CORE_READ(sk, sk_protocol); + if (info.protocol != IPPROTO_TCP && info.protocol != IPPROTO_UDP) return 0; - if (is_bind) - { - // Only emit successful binds. - if (ret != 0) - return 0; + info.port = BPF_CORE_READ(sk, __sk_common.skc_num); + info.is_bind = 1; - __u8 val = 1; - bpf_map_update_elem(&bound_sockets, &sk_key, &val, BPF_ANY); + if (info.family == AF_INET) + { + struct inet_sock *inet = (struct inet_sock *)sk; + info.addr4 = BPF_CORE_READ(inet, inet_saddr); } else { - // Only emit release if this socket was previously bound. 
- if (!bpf_map_lookup_elem(&bound_sockets, &sk_key)) - return 0; - bpf_map_delete_elem(&bound_sockets, &sk_key); + struct ipv6_pinfo *pinet6; + pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); + if (pinet6) + BPF_CORE_READ_INTO(info.addr6, pinet6, saddr.in6_u.u6_addr8); } + bpf_map_update_elem(&bound_sockets, &sk_key, &info, BPF_ANY); + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); if (!e) return 0; - e->family = family; - e->protocol = protocol; - e->port = BPF_CORE_READ(sk, __sk_common.skc_num); - e->is_bind = is_bind; - e->pad = 0; + *e = info; + bpf_ringbuf_submit(e, 0); + return 0; +} - if (family == AF_INET) - { - struct inet_sock *inet = (struct inet_sock *)sk; - e->addr4 = BPF_CORE_READ(inet, inet_saddr); - __builtin_memset(e->addr6, 0, sizeof(e->addr6)); - } - else +static __always_inline int emit_release_event(struct sock *sk) +{ + struct bind_event *e; + struct bind_event *info; + __u64 sk_key = (__u64)sk; + + if (!sk) + return 0; + + info = bpf_map_lookup_elem(&bound_sockets, &sk_key); + if (!info) + return 0; + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) { - struct ipv6_pinfo *pinet6; - e->addr4 = 0; - pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); - if (pinet6) - BPF_CORE_READ_INTO(e->addr6, pinet6, saddr.in6_u.u6_addr8); - else - __builtin_memset(e->addr6, 0, sizeof(e->addr6)); + bpf_map_delete_elem(&bound_sockets, &sk_key); + return 0; } + *e = *info; + e->is_bind = 0; + + bpf_map_delete_elem(&bound_sockets, &sk_key); bpf_ringbuf_submit(e, 0); return 0; } @@ -93,7 +98,7 @@ int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr, int addr_len, int ret) { struct sock *sk = BPF_CORE_READ(sock, sk); - return emit_event(sk, 1, ret); + return emit_bind_event(sk, ret); } // fexit/inet6_bind: called after inet6_bind() returns. 
@@ -103,7 +108,7 @@ int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr, int addr_len, int ret) { struct sock *sk = BPF_CORE_READ(sock, sk); - return emit_event(sk, 1, ret); + return emit_bind_event(sk, ret); } // fentry/inet_release: called when a socket is being closed. @@ -112,7 +117,7 @@ SEC("fentry/inet_release") int BPF_PROG(fentry_inet_release, struct socket *sock) { struct sock *sk = BPF_CORE_READ(sock, sk); - return emit_event(sk, 0, 0); + return emit_release_event(sk); } char LICENSE[] SEC("license") = "GPL"; From a70ef773a46b235823c787fc197843edcbee3088 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Tue, 14 Apr 2026 15:59:51 +0800 Subject: [PATCH 5/7] replace listen polling loop with ebpf --- src/linux/bpf/listen_monitor.bpf.c | 123 +++++++++++++++++++ src/linux/bpf/listen_monitor.h | 28 +++++ src/linux/init/CMakeLists.txt | 2 +- src/linux/init/GnsPortTracker.cpp | 45 ++++--- src/linux/init/GnsPortTracker.h | 3 +- src/linux/init/localhost.cpp | 187 ++++++++++------------------- 6 files changed, 241 insertions(+), 147 deletions(-) create mode 100644 src/linux/bpf/listen_monitor.bpf.c create mode 100644 src/linux/bpf/listen_monitor.h diff --git a/src/linux/bpf/listen_monitor.bpf.c b/src/linux/bpf/listen_monitor.bpf.c new file mode 100644 index 000000000..b89bada00 --- /dev/null +++ b/src/linux/bpf/listen_monitor.bpf.c @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include "vmlinux.h" +#include +#include +#include +#include +#include "listen_monitor.h" + +struct { + __uint(type, BPF_MAP_TYPE_RINGBUF); + __uint(max_entries, LISTEN_MONITOR_RINGBUF_SIZE); +} events SEC(".maps"); + +// Store listen info at listen() time so release can emit correct values +// even if the kernel has already cleared the socket fields. 
+struct { + __uint(type, BPF_MAP_TYPE_HASH); + __type(key, __u64); // sock pointer as key + __type(value, struct listen_event); // snapshot from listen time + __uint(max_entries, 65536); +} listening_sockets SEC(".maps"); + +static __always_inline int read_sock_info(struct sock *sk, struct listen_event *info) +{ + info->family = BPF_CORE_READ(sk, __sk_common.skc_family); + if (info->family != AF_INET && info->family != AF_INET6) + return -1; + + info->port = BPF_CORE_READ(sk, __sk_common.skc_num); + + if (info->family == AF_INET) + { + struct inet_sock *inet = (struct inet_sock *)sk; + info->addr4 = BPF_CORE_READ(inet, inet_saddr); + __builtin_memset(info->addr6, 0, sizeof(info->addr6)); + } + else + { + struct ipv6_pinfo *pinet6; + info->addr4 = 0; + pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); + if (pinet6) + BPF_CORE_READ_INTO(info->addr6, pinet6, saddr.in6_u.u6_addr8); + else + __builtin_memset(info->addr6, 0, sizeof(info->addr6)); + } + + return 0; +} + +// fexit/inet_listen: called after inet_listen() returns. +// inet_listen handles both IPv4 and IPv6. +// int inet_listen(struct socket *sock, int backlog) +SEC("fexit/inet_listen") +int BPF_PROG(fexit_inet_listen, struct socket *sock, int backlog, int ret) +{ + struct listen_event *e; + struct listen_event info = {}; + __u64 sk_key; + struct sock *sk; + + if (ret != 0) + return 0; + + sk = BPF_CORE_READ(sock, sk); + if (!sk) + return 0; + + sk_key = (__u64)sk; + + if (read_sock_info(sk, &info) < 0) + return 0; + + info.is_listen = 1; + + bpf_map_update_elem(&listening_sockets, &sk_key, &info, BPF_ANY); + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + return 0; + + *e = info; + bpf_ringbuf_submit(e, 0); + return 0; +} + +// fentry/inet_release: called when a socket is being closed. 
+// void inet_release(struct socket *sock) +SEC("fentry/inet_release") +int BPF_PROG(fentry_inet_release, struct socket *sock) +{ + struct listen_event *e; + struct listen_event *info; + struct sock *sk; + __u64 sk_key; + + sk = BPF_CORE_READ(sock, sk); + if (!sk) + return 0; + + sk_key = (__u64)sk; + + info = bpf_map_lookup_elem(&listening_sockets, &sk_key); + if (!info) + return 0; + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + { + bpf_map_delete_elem(&listening_sockets, &sk_key); + return 0; + } + + *e = *info; + e->is_listen = 0; + + bpf_map_delete_elem(&listening_sockets, &sk_key); + bpf_ringbuf_submit(e, 0); + return 0; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/src/linux/bpf/listen_monitor.h b/src/linux/bpf/listen_monitor.h new file mode 100644 index 000000000..58e684dcd --- /dev/null +++ b/src/linux/bpf/listen_monitor.h @@ -0,0 +1,28 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#ifndef __u8 +typedef unsigned char __u8; +typedef unsigned short __u16; +typedef unsigned int __u32; +typedef unsigned long long __u64; +#endif + +#ifndef AF_INET +#define AF_INET 2 +#endif +#ifndef AF_INET6 +#define AF_INET6 10 +#endif + +#define LISTEN_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB */ + +struct listen_event { + __u32 family; /* AF_INET or AF_INET6 */ + __u16 port; /* host byte order */ + __u8 is_listen; /* 1 = started listening, 0 = stopped listening */ + __u8 pad; + __u32 addr4; /* IPv4 address (network byte order) */ + __u8 addr6[16]; /* IPv6 address */ +}; diff --git a/src/linux/init/CMakeLists.txt b/src/linux/init/CMakeLists.txt index de419af55..ef19f6812 100644 --- a/src/linux/init/CMakeLists.txt +++ b/src/linux/init/CMakeLists.txt @@ -46,7 +46,7 @@ set(HEADERS wslpath.h) set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${CMAKE_CURRENT_LIST_DIR}/../netlinkutil") -set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/bind_monitor/${TARGET_PLATFORM}/include") +set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} 
-I "${WSL_BPF_SOURCE_DIR}/bpf/${TARGET_PLATFORM}/include") set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/include") set(INIT_LIBRARIES ${COMMON_LINUX_LINK_LIBRARIES} netlinkutil plan9 mountutil configfile) set(INIT_EXTRA_LIBS diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp index 6c0a4f5d7..2399662aa 100644 --- a/src/linux/init/GnsPortTracker.cpp +++ b/src/linux/init/GnsPortTracker.cpp @@ -6,24 +6,21 @@ #include "common.h" #include "GnsPortTracker.h" #include "lxinitshared.h" +#include "bind_monitor.h" #include "bind_monitor.skel.h" namespace { extern "C" int OnBindMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept +try { auto* tracker = static_cast(ctx); - const auto* event = static_cast(data); - - try - { - tracker->RequestPort(*event); - } - catch (...) - { - LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); - } - + tracker->RequestPort(data); + return 0; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); return 0; } @@ -34,34 +31,36 @@ GnsPortTracker::GnsPortTracker(std::shared_ptr hvSoc { } -void GnsPortTracker::RequestPort(const bind_event& Event) +void GnsPortTracker::RequestPort(void* Data) { + const auto* Event = static_cast(Data); + LX_GNS_PORT_ALLOCATION_REQUEST request{}; request.Header.MessageType = LxGnsMessagePortMappingRequest; request.Header.MessageSize = sizeof(request); - request.Af = Event.family; - request.Protocol = Event.protocol; - request.Port = Event.port; - request.Allocate = Event.is_bind; + request.Af = Event->family; + request.Protocol = Event->protocol; + request.Port = Event->port; + request.Allocate = Event->is_bind; static_assert(sizeof(request.Address32) == 16); - if (Event.family == AF_INET) + if (Event->family == AF_INET) { - request.Address32[0] = Event.addr4; + request.Address32[0] = Event->addr4; } else { - memcpy(request.Address32, Event.addr6, sizeof(request.Address32)); + 
memcpy(request.Address32, Event->addr6, sizeof(request.Address32)); } const auto& response = m_hvSocketChannel->Transaction(request); GNS_LOG_INFO( "Port {} request: family ({}) port ({}) protocol ({}) result ({})", - Event.is_bind ? "allocate" : "release", - Event.family, - Event.port, - Event.protocol, + Event->is_bind ? "allocate" : "release", + Event->family, + Event->port, + Event->protocol, response.Result); } diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h index 4da0516c8..44816efc1 100644 --- a/src/linux/init/GnsPortTracker.h +++ b/src/linux/init/GnsPortTracker.h @@ -4,7 +4,6 @@ #include #include "SocketChannel.h" #include "util.h" -#include "bind_monitor.h" class GnsPortTracker { @@ -16,7 +15,7 @@ class GnsPortTracker void Run(); - void RequestPort(const bind_event& Event); + void RequestPort(void* data); private: std::shared_ptr m_hvSocketChannel; diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index 37f934faa..381105783 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -12,21 +12,20 @@ #include #include #include -#include -#include #include +#include #include +#include + #include "util.h" #include "SocketChannel.h" #include "GnsPortTracker.h" #include "SecCompDispatcher.h" #include "seccomp_defs.h" #include "CommandLine.h" -#include "NetlinkChannel.h" -#include "NetlinkTransactionError.h" - -#define TCP_LISTEN 10 +#include "listen_monitor.h" +#include "listen_monitor.skel.h" namespace { @@ -149,64 +148,6 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket) return; } -std::vector QueryListeningSockets(NetlinkChannel& channel) -{ - std::vector sockets{}; - try - { - inet_diag_req_v2 message{}; - message.sdiag_protocol = IPPROTO_TCP; - message.idiag_states = (1 << TCP_LISTEN); - - auto onMessage = [&](const NetlinkResponse& response) { - for (const auto& e : response.Messages(SOCK_DIAG_BY_FAMILY)) - { - const auto* payload = e.Payload(); - sockaddr_storage 
sock{}; - - if (payload->idiag_family == AF_INET) - { - auto* ipv4 = reinterpret_cast(&sock); - ipv4->sin_family = AF_INET; - ipv4->sin_addr.s_addr = payload->id.idiag_src[0]; - ipv4->sin_port = payload->id.idiag_sport; - } - else if (payload->idiag_family == AF_INET6) - { - auto* ipv6 = reinterpret_cast(&sock); - ipv6->sin6_family = AF_INET6; - static_assert(sizeof(ipv6->sin6_addr.s6_addr32) == sizeof(payload->id.idiag_src)); - memcpy(ipv6->sin6_addr.s6_addr32, payload->id.idiag_src, sizeof(ipv6->sin6_addr.s6_addr32)); - ipv6->sin6_port = payload->id.idiag_sport; - } - - sockets.emplace_back(sock); - } - }; - - // Query IPv4 listening sockets. - { - message.sdiag_family = AF_INET; - auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - - // Query IPv6 listening sockets. - { - message.sdiag_family = AF_INET6; - auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - } - catch (const NetlinkTransactionError& e) - { - // Log but don't fail - network state might be temporarily unavailable - LOG_ERROR("Failed to query listening sockets via sock_diag: {}", e.what()); - } - - return sockets; -} - int SendRelayListenerSocket(wsl::shared::SocketChannel& channel, int hvSocketPort) try { @@ -263,87 +204,91 @@ try } CATCH_RETURN_ERRNO(); -bool IsSameSockAddr(const sockaddr_storage& left, const sockaddr_storage& right) +extern "C" int OnListenMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept +try { - if (left.ss_family != right.ss_family) + auto* channel = static_cast(ctx); + const auto* event = static_cast(data); + + sockaddr_storage sock{}; + if (event->family == AF_INET) + { + auto* ipv4 = reinterpret_cast(&sock); + ipv4->sin_family = AF_INET; + ipv4->sin_addr.s_addr = event->addr4; + ipv4->sin_port = htons(event->port); + } + else if (event->family == AF_INET6) + { + auto* ipv6 = reinterpret_cast(&sock); + 
ipv6->sin6_family = AF_INET6; + memcpy(ipv6->sin6_addr.s6_addr, event->addr6, sizeof(ipv6->sin6_addr.s6_addr)); + ipv6->sin6_port = htons(event->port); + } + else { - return false; + return 0; } - if (left.ss_family == AF_INET) + if (event->is_listen) { - auto leftIpv4 = reinterpret_cast(&left); - auto rightIpv4 = reinterpret_cast(&right); - return (leftIpv4->sin_addr.s_addr == rightIpv4->sin_addr.s_addr && leftIpv4->sin_port == rightIpv4->sin_port); + StartHostListener(*channel, sock); } - else if (left.ss_family == AF_INET6) + else { - auto leftIpv6 = reinterpret_cast(&left); - auto rightIpv6 = reinterpret_cast(&right); - return (leftIpv6->sin6_port == rightIpv6->sin6_port && memcmp(&leftIpv6->sin6_addr, &rightIpv6->sin6_addr, sizeof(in6_addr)) == 0); + StopHostListener(*channel, sock); } - FATAL_ERROR("Unrecognized socket family {}", left.ss_family); - return false; + return 0; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION_MSG("Error processing listen monitor event"); + return 0; } -// Monitor listening TCP sockets using sock_diag netlink interface. int MonitorListeningSockets(wsl::shared::SocketChannel& channel) { - NetlinkChannel netlinkChannel(SOCK_RAW, NETLINK_SOCK_DIAG); - std::vector relays{}; - int result = 0; - - for (;;) + auto* skel = listen_monitor_bpf__open_and_load(); + if (!skel) { - auto sockets = QueryListeningSockets(netlinkChannel); + LOG_ERROR("Failed to open/load listen monitor BPF program, {}", errno); + return -1; + } - // Stop any relays that no longer match listening ports. 
- std::erase_if(relays, [&](const auto& entry) { - auto found = - std::find_if(sockets.begin(), sockets.end(), [&](const auto& socket) { return IsSameSockAddr(entry, socket); }); + auto destroySkel = wil::scope_exit([&] { listen_monitor_bpf__destroy(skel); }); - bool remove = (found == sockets.end()); - if (remove) - { - if (StopHostListener(channel, entry) < 0) - { - result = -1; - } - } + if (listen_monitor_bpf__attach(skel) != 0) + { + LOG_ERROR("Failed to attach listen monitor BPF program, {}", errno); + return -1; + } - return remove; - }); + auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), OnListenMonitorEvent, &channel, nullptr); + if (!rb) + { + LOG_ERROR("Failed to create ring buffer, {}", errno); + return -1; + } - // Create relays for any new ports. - std::for_each(sockets.begin(), sockets.end(), [&](const auto& socket) { - auto found = - std::find_if(relays.begin(), relays.end(), [&](const auto& entry) { return IsSameSockAddr(entry, socket); }); + auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); - if (found == relays.end()) - { - if (StartHostListener(channel, socket) < 0) - { - result = -1; - } - else - { - relays.push_back(socket); - } - } - }); + GNS_LOG_INFO("BPF listen monitor attached and running"); - // Ensure all start / stop operations were successful. - if (result < 0) + for (;;) + { + int err = ring_buffer__poll(rb, -1); + if (err == -EINTR) { - break; + continue; } - // Sleep before scanning again. 
- std::this_thread::sleep_for(std::chrono::seconds(1)); + if (err < 0) + { + LOG_ERROR("ring_buffer__poll failed, {}", err); + return -1; + } } - - return result; } } // namespace From 79a77ec0f4ceff9af483d79b0c56b0bfd8dd5563 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Tue, 14 Apr 2026 18:15:37 +0800 Subject: [PATCH 6/7] protect port tracker channel --- src/linux/init/GnsPortTracker.cpp | 9 ++++++--- src/linux/init/GnsPortTracker.h | 4 +++- src/linux/init/localhost.cpp | 9 +++++++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp index 2399662aa..6560f558e 100644 --- a/src/linux/init/GnsPortTracker.cpp +++ b/src/linux/init/GnsPortTracker.cpp @@ -26,8 +26,8 @@ catch (...) } // namespace -GnsPortTracker::GnsPortTracker(std::shared_ptr hvSocketChannel) : - m_hvSocketChannel(std::move(hvSocketChannel)) +GnsPortTracker::GnsPortTracker(std::shared_ptr hvSocketChannel, std::shared_ptr channelMutex) : + m_hvSocketChannel(std::move(hvSocketChannel)), m_channelMutex(std::move(channelMutex)) { } @@ -53,7 +53,10 @@ void GnsPortTracker::RequestPort(void* Data) memcpy(request.Address32, Event->addr6, sizeof(request.Address32)); } - const auto& response = m_hvSocketChannel->Transaction(request); + const auto& response = [&]() { + std::lock_guard lock(*m_channelMutex); + return m_hvSocketChannel->Transaction(request); + }(); GNS_LOG_INFO( "Port {} request: family ({}) port ({}) protocol ({}) result ({})", diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h index 44816efc1..1ddbbe71f 100644 --- a/src/linux/init/GnsPortTracker.h +++ b/src/linux/init/GnsPortTracker.h @@ -2,13 +2,14 @@ #pragma once #include +#include #include "SocketChannel.h" #include "util.h" class GnsPortTracker { public: - GnsPortTracker(std::shared_ptr hvSocketChannel); + GnsPortTracker(std::shared_ptr hvSocketChannel, std::shared_ptr channelMutex); NON_COPYABLE(GnsPortTracker); 
NON_MOVABLE(GnsPortTracker); @@ -19,4 +20,5 @@ class GnsPortTracker private: std::shared_ptr m_hvSocketChannel; + std::shared_ptr m_channelMutex; }; diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index 381105783..38e5d3218 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -1,6 +1,7 @@ // Copyright (C) Microsoft Corporation. All rights reserved. #include "common.h" #include +#include #include #include #include @@ -403,9 +404,11 @@ int RunPortTracker(int Argc, char** Argv) return 1; } + auto channelMutex = std::make_shared(); + auto seccompDispatcher = std::make_shared(BpfFd); - seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher](auto notification) -> int { + seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher, channelMutex](auto notification) -> int { LX_GNS_TUN_BRIDGE_REQUEST request{}; request.Header.MessageType = LxGnsMessageIfStateChangeRequest; request.Header.MessageSize = sizeof(request); @@ -419,12 +422,14 @@ int RunPortTracker(int Argc, char** Argv) auto& ifRequest = *reinterpret_cast(ifreqMemory->data()); memcpy(request.InterfaceName, ifRequest.ifr_ifrn.ifrn_name, sizeof(request.InterfaceName)); request.InterfaceUp = ifRequest.ifr_ifru.ifru_flags & IFF_UP; + + std::lock_guard lock(*channelMutex); const auto& reply = hvSocketChannel->Transaction(request); return reply.Result; }); - GnsPortTracker portTracker(hvSocketChannel); + GnsPortTracker portTracker(hvSocketChannel, channelMutex); try { portTracker.Run(); From 57c1679e0a3ecbf86268a2034c1810eeefb4d89b Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Wed, 15 Apr 2026 14:16:07 +0800 Subject: [PATCH 7/7] hybrid solution --- src/linux/bpf/bind_monitor.bpf.c | 83 +---- src/linux/bpf/bind_monitor.h | 3 +- src/linux/init/GnsPortTracker.cpp | 567 +++++++++++++++++++++++++++--- src/linux/init/GnsPortTracker.h | 138 +++++++- src/linux/init/localhost.cpp | 28 +- src/linux/init/main.cpp | 51 ++- 6 
files changed, 744 insertions(+), 126 deletions(-) diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c index 478137fee..4a2eaa586 100644 --- a/src/linux/bpf/bind_monitor.bpf.c +++ b/src/linux/bpf/bind_monitor.bpf.c @@ -12,87 +12,52 @@ struct { __uint(max_entries, BIND_MONITOR_RINGBUF_SIZE); } events SEC(".maps"); -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __type(key, __u64); // sock pointer as key - __type(value, struct bind_event); // snapshot from bind time - __uint(max_entries, 65536); -} bound_sockets SEC(".maps"); - static __always_inline int emit_bind_event(struct sock *sk, int ret) { struct bind_event *e; - struct bind_event info = {}; - __u64 sk_key = (__u64)sk; if (!sk || ret != 0) return 0; - info.family = BPF_CORE_READ(sk, __sk_common.skc_family); - if (info.family != AF_INET && info.family != AF_INET6) + __u16 family = BPF_CORE_READ(sk, __sk_common.skc_family); + if (family != AF_INET && family != AF_INET6) return 0; - info.protocol = BPF_CORE_READ(sk, sk_protocol); - if (info.protocol != IPPROTO_TCP && info.protocol != IPPROTO_UDP) + __u16 protocol = BPF_CORE_READ(sk, sk_protocol); + if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) return 0; - info.port = BPF_CORE_READ(sk, __sk_common.skc_num); - info.is_bind = 1; + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + return 0; - if (info.family == AF_INET) + e->family = family; + e->protocol = protocol; + e->port = BPF_CORE_READ(sk, __sk_common.skc_num); + e->pad = 0; + + if (family == AF_INET) { struct inet_sock *inet = (struct inet_sock *)sk; - info.addr4 = BPF_CORE_READ(inet, inet_saddr); + e->addr4 = BPF_CORE_READ(inet, inet_saddr); + __builtin_memset(e->addr6, 0, sizeof(e->addr6)); } else { struct ipv6_pinfo *pinet6; + e->addr4 = 0; pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); if (pinet6) - BPF_CORE_READ_INTO(info.addr6, pinet6, saddr.in6_u.u6_addr8); + BPF_CORE_READ_INTO(e->addr6, pinet6, saddr.in6_u.u6_addr8); + else + 
__builtin_memset(e->addr6, 0, sizeof(e->addr6)); } - bpf_map_update_elem(&bound_sockets, &sk_key, &info, BPF_ANY); - - e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); - if (!e) - return 0; - - *e = info; - bpf_ringbuf_submit(e, 0); - return 0; -} - -static __always_inline int emit_release_event(struct sock *sk) -{ - struct bind_event *e; - struct bind_event *info; - __u64 sk_key = (__u64)sk; - - if (!sk) - return 0; - - info = bpf_map_lookup_elem(&bound_sockets, &sk_key); - if (!info) - return 0; - - e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); - if (!e) - { - bpf_map_delete_elem(&bound_sockets, &sk_key); - return 0; - } - - *e = *info; - e->is_bind = 0; - - bpf_map_delete_elem(&bound_sockets, &sk_key); bpf_ringbuf_submit(e, 0); return 0; } // fexit/inet_bind: called after inet_bind() returns. -// int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) SEC("fexit/inet_bind") int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr, int addr_len, int ret) @@ -102,7 +67,6 @@ int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr, } // fexit/inet6_bind: called after inet6_bind() returns. -// int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) SEC("fexit/inet6_bind") int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr, int addr_len, int ret) @@ -111,13 +75,4 @@ int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr, return emit_bind_event(sk, ret); } -// fentry/inet_release: called when a socket is being closed. 
-// void inet_release(struct socket *sock) -SEC("fentry/inet_release") -int BPF_PROG(fentry_inet_release, struct socket *sock) -{ - struct sock *sk = BPF_CORE_READ(sock, sk); - return emit_release_event(sk); -} - char LICENSE[] SEC("license") = "GPL"; diff --git a/src/linux/bpf/bind_monitor.h b/src/linux/bpf/bind_monitor.h index 9f093bda0..6138e4a48 100644 --- a/src/linux/bpf/bind_monitor.h +++ b/src/linux/bpf/bind_monitor.h @@ -29,8 +29,7 @@ struct bind_event { __u32 family; /* AF_INET or AF_INET6 */ __u32 protocol; /* IPPROTO_TCP or IPPROTO_UDP */ __u16 port; /* host byte order */ - __u8 is_bind; /* 1 = bind, 0 = release */ - __u8 pad; + __u16 pad; __u32 addr4; /* IPv4 address (network byte order) */ __u8 addr6[16]; /* IPv6 address */ }; diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp index 6560f558e..dfa0ca145 100644 --- a/src/linux/init/GnsPortTracker.cpp +++ b/src/linux/init/GnsPortTracker.cpp @@ -1,96 +1,565 @@ // Copyright (C) Microsoft Corporation. All rights reserved. 
+#include +#include #include -#include +#include +#include +#include +#include +#include +#include #include #include "common.h" +#include "NetlinkTransactionError.h" #include "GnsPortTracker.h" #include "lxinitshared.h" #include "bind_monitor.h" #include "bind_monitor.skel.h" -namespace { +constexpr size_t c_bind_timeout_seconds = 60; +constexpr auto c_sock_diag_refresh_delay = std::chrono::milliseconds(500); +constexpr auto c_sock_diag_poll_timeout = std::chrono::milliseconds(10); +constexpr auto c_bpf_poll_timeout = std::chrono::milliseconds(500); -extern "C" int OnBindMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept -try +GnsPortTracker::GnsPortTracker( + std::shared_ptr hvSocketChannel, + NetlinkChannel&& netlinkChannel, + std::shared_ptr seccompDispatcher, + std::shared_ptr channelMutex) : + m_hvSocketChannel(std::move(hvSocketChannel)), + m_channelMutex(std::move(channelMutex)), + m_channel(std::move(netlinkChannel)), + m_seccompDispatcher(seccompDispatcher) { - auto* tracker = static_cast(ctx); - tracker->RequestPort(data); - return 0; + m_networkNamespace = std::filesystem::read_symlink("/proc/self/ns/net").string(); } -catch (...) + +void GnsPortTracker::RunPortRefresh() { - LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); - return 0; + UtilSetThreadName("GnsPortTracker"); + + // The polling of bound sockets is done in a separate thread because + // sock_diag sometimes fails with EBUSY when a bind() is in progress. + // Doing this in a separate thread allows the main thread not to be delayed + // because of transient sock_diag failures + + for (;;) + { + // Netlink will sometimes return EBUSY. 
Don't fail for that + try + { + std::promise resume; + auto result = PortRefreshResult{ListAllocatedPorts(), time(nullptr), std::bind(&std::promise::set_value, &resume)}; + m_allocatedPortsRefresh.set_value(result); + + resume.get_future().wait(); + } + catch (const NetlinkTransactionError& e) + { + if (e.Error().value_or(0) != -EBUSY) + { + std::cerr << "Failed to refresh allocated ports, " << e.what() << std::endl; + } + } + + std::this_thread::sleep_for(c_sock_diag_refresh_delay); + } +} + +int GnsPortTracker::ProcessSecCompNotification(seccomp_notif* notification) +{ + seccomp_notif notificationCopy = *notification; + m_request.post(notificationCopy); + return m_reply.get(); } -} // namespace +void GnsPortTracker::OnBindEvent(void* Data) +{ + const auto* event = static_cast(Data); + + in6_addr address = {}; + if (event->family == AF_INET) + { + address.s6_addr32[0] = event->addr4; + } + else if (event->family == AF_INET6) + { + memcpy(address.s6_addr32, event->addr6, sizeof(address.s6_addr32)); + } + else + { + return; + } + + PortAllocation allocation(event->port, event->family, event->protocol, address); + + // Check if this port is already tracked by the seccomp path (non-zero bind). + // If it is, this is not a port-0 bind — skip it. + { + std::lock_guard lock(m_portsMutex); + if (m_allocatedPorts.contains(allocation)) + { + return; + } + } + + // This is a port-0 bind that seccomp let through. Report it to the host. + const auto result = HandleRequest(allocation); + if (result == 0) + { + TrackPort(std::move(allocation)); + } + else + { + GNS_LOG_ERROR( + "Failed to register port-0 bind: family ({}) port ({}) protocol ({}), error {}", + allocation.Family, + allocation.Port, + allocation.Protocol, + result); + } +} + +void GnsPortTracker::Run() +{ + std::thread{std::bind(&GnsPortTracker::RunPortRefresh, this)}.detach(); + + // Start eBPF bind monitor on a background thread to handle port-0 bind resolution. 
+ std::thread{[this]() { + try + { + auto* skel = bind_monitor_bpf__open_and_load(); + if (!skel) + { + LOG_ERROR("Failed to open/load bind monitor BPF program, {}", errno); + return; + } + + auto destroySkel = wil::scope_exit([&] { bind_monitor_bpf__destroy(skel); }); + + if (bind_monitor_bpf__attach(skel) != 0) + { + LOG_ERROR("Failed to attach bind monitor BPF program, {}", errno); + return; + } -GnsPortTracker::GnsPortTracker(std::shared_ptr hvSocketChannel, std::shared_ptr channelMutex) : - m_hvSocketChannel(std::move(hvSocketChannel)), m_channelMutex(std::move(channelMutex)) + auto onEvent = [](void* ctx, void* data, size_t dataSz) noexcept -> int { + try + { + static_cast(ctx)->OnBindEvent(data); + } + catch (...) + { + LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); + } + return 0; + }; + + auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), onEvent, this, nullptr); + if (!rb) + { + LOG_ERROR("Failed to create bind monitor ring buffer, {}", errno); + return; + } + + auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); + + GNS_LOG_INFO("BPF bind monitor for port-0 resolution attached and running"); + + for (;;) + { + int err = ring_buffer__poll(rb, -1); + if (err == -EINTR) + { + continue; + } + + if (err < 0) + { + LOG_ERROR("bind monitor ring_buffer__poll failed, {}", err); + return; + } + } + } + CATCH_LOG() + }}.detach(); + + auto future = std::make_optional(m_allocatedPortsRefresh.get_future()); + std::optional refreshResult; + + for (;;) + { + std::optional bindCall; + try + { + bindCall = ReadNextRequest(); + } + catch (const std::exception& e) + { + GNS_LOG_ERROR("Failed to read bind request, {}", e.what()); + } + + if (bindCall.has_value()) + { + int result = 0; + if (bindCall->Request.has_value()) + { + PortAllocation& allocationRequest = bindCall->Request.value(); + result = HandleRequest(allocationRequest); + if (result == 0) + { + TrackPort(allocationRequest); + GNS_LOG_INFO( + "Tracking bind call: family 
({}) port ({}) protocol ({})", + allocationRequest.Family, + allocationRequest.Port, + allocationRequest.Protocol); + } + } + + try + { + CompleteRequest(bindCall->CallId, result); + } + catch (const std::exception& e) + { + GNS_LOG_ERROR("Failed to complete bind request, {}", e.what()); + } + + // Port-0 binds are now handled by the eBPF callback (OnBindEvent), + // so no deferred resolution is needed here. + } + + // If bindCall is empty, then the read() timed out. Look for any closed port + if (future.has_value() && future->wait_for(c_sock_diag_poll_timeout) == std::future_status::ready) + { + refreshResult.emplace(future->get()); + future.reset(); + m_allocatedPortsRefresh = {}; + + if (!bindCall.has_value()) + { + std::lock_guard lock(m_portsMutex); + OnRefreshAllocatedPorts(refreshResult->Ports, refreshResult->Timestamp); + } + } + + // Only look at bound ports if there's something to deallocate to avoid wasting cycles + if (refreshResult.has_value()) + { + std::lock_guard lock(m_portsMutex); + if (!m_allocatedPorts.empty()) + { + future = m_allocatedPortsRefresh.get_future(); + refreshResult->Resume(); + refreshResult.reset(); + } + } + } +} + +std::set GnsPortTracker::ListAllocatedPorts() { + std::set ports; + + inet_diag_req_v2 message{}; + message.sdiag_family = AF_INET; + message.sdiag_protocol = IPPROTO_TCP; + message.idiag_states = ~0; + + auto onMessage = [&](const NetlinkResponse& response) { + for (const auto& e : response.Messages(SOCK_DIAG_BY_FAMILY)) + { + const auto* payload = e.Payload(); + in6_addr address = {}; + + if (payload->idiag_family == AF_INET6) + { + static_assert(sizeof(address.s6_addr32) == 16); + static_assert(sizeof(address.s6_addr32) == sizeof(payload->id.idiag_src)); + memcpy(address.s6_addr32, payload->id.idiag_src, sizeof(address.s6_addr32)); + } + else + { + address.s6_addr32[0] = payload->id.idiag_src[0]; + } + + ports.emplace(ntohs(payload->id.idiag_sport), static_cast(payload->idiag_family), 
static_cast(message.sdiag_protocol), address); + } + }; + + { + auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); + transaction.Execute(onMessage); + } + + message.sdiag_family = AF_INET6; + { + auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); + transaction.Execute(onMessage); + } + + message.sdiag_protocol = IPPROTO_UDP; + message.sdiag_family = AF_INET; + { + auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); + transaction.Execute(onMessage); + } + + message.sdiag_family = AF_INET6; + { + auto transaction = m_channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); + transaction.Execute(onMessage); + } + + return ports; } -void GnsPortTracker::RequestPort(void* Data) +void GnsPortTracker::OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp) { - const auto* Event = static_cast(Data); + // m_portsMutex must be held by caller. + // + // Because there's no way to get notified when the bind() call actually completes, it's possible + // that this method is called before the bind() completion and so the port allocation may not be visible yet. 
+ // To avoid deallocating ports that simply haven't been done allocating yet, m_allocatedPorts stores a timeout + // that prevents deallocating the port unless: + // + // - The port has been seen to be allocated (if so, then the timeout is empty) + // - The timeout has expired + + for (auto it = m_allocatedPorts.begin(); it != m_allocatedPorts.end();) + { + if (Ports.find(it->first) == Ports.end()) + { + if (!it->second.has_value() || it->second.value() < Timestamp) + { + auto result = RequestPort(it->first, false); + if (result != 0) + { + std::cerr << "GnsPortTracker: Failed to deallocate port " << it->first << ", " << result << std::endl; + } + GNS_LOG_INFO( + "No longer tracking bind call: family ({}) port ({}) protocol ({})", + it->first.Family, + it->first.Port, + it->first.Protocol); + + it = m_allocatedPorts.erase(it); + continue; + } + } + else + { + it->second.reset(); // The port is known to be allocated, remove the timeout + } + + it++; + } +} + +int GnsPortTracker::RequestPort(const PortAllocation& Port, bool allocate) +{ LX_GNS_PORT_ALLOCATION_REQUEST request{}; request.Header.MessageType = LxGnsMessagePortMappingRequest; request.Header.MessageSize = sizeof(request); - request.Af = Event->family; - request.Protocol = Event->protocol; - request.Port = Event->port; - request.Allocate = Event->is_bind; - + request.Af = Port.Family; + request.Protocol = Port.Protocol; + request.Port = Port.Port; + request.Allocate = allocate; static_assert(sizeof(request.Address32) == 16); - if (Event->family == AF_INET) + static_assert(sizeof(request.Address32) == sizeof(Port.Address.s6_addr32)); + memcpy(request.Address32, Port.Address.s6_addr32, sizeof(request.Address32)); + + std::lock_guard lock(*m_channelMutex); + const auto& response = m_hvSocketChannel->Transaction(request); + + return response.Result; +} + +int GnsPortTracker::HandleRequest(const PortAllocation& Port) +{ { - request.Address32[0] = Event->addr4; + std::lock_guard lock(m_portsMutex); + if 
(m_allocatedPorts.contains(Port)) + { + GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol); + return 0; + } } - else + + // Ask the host for this port otherwise + const auto error = RequestPort(Port, true); + GNS_LOG_INFO( + "Requested the host for port allocation on port (family {}, port {}, protocol {}) - returned {}", Port.Family, Port.Port, Port.Protocol, error); + return error; +} + +std::optional GnsPortTracker::ReadNextRequest() +{ + auto request_value = m_request.try_get(c_bpf_poll_timeout); + if (!request_value.has_value()) { - memcpy(request.Address32, Event->addr6, sizeof(request.Address32)); + return {}; } - const auto& response = [&]() { - std::lock_guard lock(*m_channelMutex); - return m_hvSocketChannel->Transaction(request); - }(); + auto callInfo = request_value.value(); - GNS_LOG_INFO( - "Port {} request: family ({}) port ({}) protocol ({}) result ({})", - Event->is_bind ? "allocate" : "release", - Event->family, - Event->port, - Event->protocol, - response.Result); + try + { + return GetCallInfo(callInfo.id, callInfo.pid, callInfo.data.arch, callInfo.data.nr, gsl::make_span(callInfo.data.args)); + } + catch (const std::exception& e) + { + GNS_LOG_ERROR("Failed to read bind() call info with ID {} for pid {}, {}", callInfo.id, callInfo.pid, e.what()); + return {{{}, false, callInfo.id}}; + } } -void GnsPortTracker::Run() +std::optional GnsPortTracker::GetCallInfo( + uint64_t CallId, pid_t Pid, int Arch, int SysCallNumber, const gsl::span& Arguments) { - auto* skel = bind_monitor_bpf__open_and_load(); - THROW_LAST_ERROR_IF(!skel); + auto ParseSocket = [&](int Socket, size_t AddressPtr, size_t AddressLength) -> std::optional { + if (AddressLength < sizeof(sockaddr)) + { + return {{{}, false, CallId}}; + } - auto destroySkel = wil::scope_exit([&] { bind_monitor_bpf__destroy(skel); }); + auto networkNamespace = std::filesystem::read_symlink(std::format("/proc/{}/ns/net", 
Pid)).string(); + if (networkNamespace != m_networkNamespace) + { + GNS_LOG_INFO("Skipping bind() call for pid {} in network namespace {}", Pid, networkNamespace.c_str()); + return {{{}, false, CallId}}; + } - THROW_LAST_ERROR_IF(bind_monitor_bpf__attach(skel) != 0); + auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, AddressPtr, AddressLength); + if (!processMemory.has_value()) + { + throw RuntimeErrorWithSourceLocation("Failed to read process memory"); + } - auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), OnBindMonitorEvent, this, nullptr); - THROW_LAST_ERROR_IF(!rb); + sockaddr& address = *reinterpret_cast(processMemory->data()); - auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); + if ((address.sa_family != AF_INET && address.sa_family != AF_INET6) || + (address.sa_family == AF_INET6 && AddressLength < sizeof(sockaddr_in6))) + { + return {{{}, false, CallId}}; + } - GNS_LOG_INFO("BPF bind monitor attached and running"); + static_assert(sizeof(sockaddr_in) <= sizeof(sockaddr)); - for (;;) + const auto* inAddr = reinterpret_cast(&address); + in_port_t port = ntohs(inAddr->sin_port); + if (port == 0) + { + // Port 0 means the kernel will assign an ephemeral port. + // The eBPF fexit handler (OnBindEvent) will resolve the actual port + // after bind() completes. Just let the call through. 
+ return {{{}, true, CallId}}; + } + + in6_addr storedAddress = {}; + + if (address.sa_family == AF_INET) + { + storedAddress.s6_addr32[0] = inAddr->sin_addr.s_addr; + } + else + { + const auto* inAddr6 = reinterpret_cast(&address); + memcpy(storedAddress.s6_addr32, inAddr6->sin6_addr.s6_addr32, sizeof(storedAddress.s6_addr32)); + } + + const int protocol = GetSocketProtocol(Pid, Socket); + + if (!m_seccompDispatcher->ValidateCookie(CallId)) + { + throw RuntimeErrorWithSourceLocation(std::format("Invalid call id {}", CallId)); + } + + return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, false, CallId}}; + }; +#ifdef __x86_64__ + if (Arch & __AUDIT_ARCH_64BIT) { - int err = ring_buffer__poll(rb, -1 /* block until event */); - if (err == -EINTR) + return ParseSocket(Arguments[0], Arguments[1], Arguments[2]); + } + else + { + if (Arguments[0] != SYS_BIND) { - continue; + return {{{}, false, CallId}}; + } + auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, Arguments[1], sizeof(uint32_t) * 3); + if (!processMemory.has_value()) + { + throw RuntimeErrorWithSourceLocation("Failed to read process memory"); } - THROW_LAST_ERROR_IF(err < 0); + uint32_t* CopiedArguments = reinterpret_cast(processMemory->data()); + return ParseSocket(CopiedArguments[0], CopiedArguments[1], CopiedArguments[2]); } +#else + return ParseSocket(Arguments[0], Arguments[1], Arguments[2]); +#endif +} + +void GnsPortTracker::CompleteRequest(uint64_t id, int result) +{ + m_reply.post(result); +} + +int GnsPortTracker::GetSocketProtocol(int pid, int fd) +{ + const auto path = std::format("/proc/{}/fd/{}", pid, fd); + + // Because there's a race between the time where the buffer size is determined + // and the actual getxattr() call, retry until the buffer size is big enough + std::string protocol; + int result = -1; + do + { + int bufferSize = Syscall(getxattr, path.c_str(), "system.sockprotoname", nullptr, 0); + protocol.resize(std::max(0, bufferSize - 1)); 
+ + result = getxattr(path.c_str(), "system.sockprotoname", protocol.data(), bufferSize); + } while (result < 0 && errno == ERANGE); + + if (result < 0) + { + throw RuntimeErrorWithSourceLocation(std::format("Failed to read protocol for socket: {}, {}", path, errno)); + } + + // In case the size of the attribute shrunk between the two getxattr calls + protocol.resize(std::max(0, result - 1)); + + if (protocol == "TCP" || protocol == "TCPv6") + { + return IPPROTO_TCP; + } + else if (protocol == "UDP" || protocol == "UDPv6") + { + return IPPROTO_UDP; + } + + throw RuntimeErrorWithSourceLocation(std::format("Unexpected IP socket protocol: {}", protocol)); +} + +void GnsPortTracker::TrackPort(PortAllocation allocation) +try +{ + // Use insert_or_assign so the deallocation timeout is refreshed if the same + // port key is already present (emplace would silently keep the old entry). + std::lock_guard lock(m_portsMutex); + m_allocatedPorts.insert_or_assign(std::move(allocation), std::make_optional(time(nullptr) + c_bind_timeout_seconds)); +} +catch (const std::exception& e) +{ + GNS_LOG_ERROR("Failed to track port allocation, {}", e.what()); +} + +std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& entry) +{ + return out << "Port=" << entry.Port << ", Family=" << entry.Family << ", Protocol=" << entry.Protocol; } diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h index 1ddbbe71f..2dfa2a8dd 100644 --- a/src/linux/init/GnsPortTracker.h +++ b/src/linux/init/GnsPortTracker.h @@ -1,24 +1,154 @@ // Copyright (C) Microsoft Corporation. All rights reserved. 
#pragma once -#include +#include +#include #include -#include "SocketChannel.h" +#include +#include +#include +#include +#include +#include +#include +#include #include "util.h" +#include +#include "waitablevalue.h" +#include "SecCompDispatcher.h" +#include "SocketChannel.h" class GnsPortTracker { public: - GnsPortTracker(std::shared_ptr hvSocketChannel, std::shared_ptr channelMutex); + GnsPortTracker( + std::shared_ptr hvSocketChannel, + NetlinkChannel&& netlinkChannel, + std::shared_ptr seccompDispatcher, + std::shared_ptr channelMutex); NON_COPYABLE(GnsPortTracker); NON_MOVABLE(GnsPortTracker); void Run(); - void RequestPort(void* data); + int ProcessSecCompNotification(seccomp_notif* notification); + + // Called from the eBPF ring buffer callback thread to handle port-0 bind resolution. + void OnBindEvent(void* data); + + struct PortAllocation + { + in6_addr Address = {}; + std::uint16_t Port = {}; + int Family = {}; + int Protocol = {}; + + PortAllocation(PortAllocation&&) = default; + PortAllocation(const PortAllocation&) = default; + + PortAllocation& operator=(PortAllocation&&) = default; + PortAllocation& operator=(const PortAllocation&) = default; + + PortAllocation(std::uint16_t Port, int Family, int Protocol, in6_addr& Address) : + Port(Port), Family(Family), Protocol(Protocol) + { + memcpy(this->Address.s6_addr32, Address.s6_addr32, sizeof(this->Address.s6_addr32)); + } + + bool operator<(const PortAllocation& other) const + { + if (Port < other.Port) + { + return true; + } + else if (Port > other.Port) + { + return false; + } + + if (Family < other.Family) + { + return true; + } + else if (Family > other.Family) + { + return false; + } + + if (Protocol < other.Protocol) + { + return true; + } + else if (Protocol > other.Protocol) + { + return false; + } + + static_assert(sizeof(Address.s6_addr32) == 16); + if (int res = memcmp(Address.s6_addr32, other.Address.s6_addr32, sizeof(Address.s6_addr32)); res < 0) + { + return true; + } + else if (res > 0) + 
{ + return false; + } + + return false; + } + }; + + struct BindCall + { + std::optional Request; + bool IsPortZeroBind = false; + std::uint64_t CallId; + }; + + struct PortRefreshResult + { + std::set Ports; + time_t Timestamp; + std::function Resume; + }; private: + void OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp); + + void RunPortRefresh(); + + std::set ListAllocatedPorts(); + + std::optional ReadNextRequest(); + + std::optional GetCallInfo(uint64_t CallId, pid_t Pid, int Arch, int SysCallNumber, const gsl::span& Arguments); + + int RequestPort(const PortAllocation& Port, bool Allocate); + + int HandleRequest(const PortAllocation& Request); + + void CompleteRequest(uint64_t Id, int Result); + + static int GetSocketProtocol(int Pid, int Fd); + + void TrackPort(PortAllocation allocation); + + // Protects m_allocatedPorts for concurrent access from the main loop and the eBPF callback thread. + std::mutex m_portsMutex; + std::map> m_allocatedPorts; + std::shared_ptr m_hvSocketChannel; std::shared_ptr m_channelMutex; + NetlinkChannel m_channel; + std::promise m_allocatedPortsRefresh; + + WaitableValue m_request; + WaitableValue m_reply; + + std::shared_ptr m_seccompDispatcher; + + std::string m_networkNamespace; }; + +std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& portAllocation); diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index 38e5d3218..3d7e3c3db 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -25,6 +25,8 @@ #include "SecCompDispatcher.h" #include "seccomp_defs.h" #include "CommandLine.h" +#include "NetlinkChannel.h" +#include "NetlinkTransactionError.h" #include "listen_monitor.h" #include "listen_monitor.skel.h" @@ -349,6 +351,8 @@ int RunPortTracker(int Argc, char** Argv) " fd" " [" INIT_BPF_FD_ARG " fd]" + " [" INIT_NETLINK_FD_ARG + " fd]" " [" INIT_PORT_TRACKER_LOCALHOST_RELAY " fd]\n"; // This is only supported on VM mode. 
@@ -362,11 +366,13 @@ int RunPortTracker(int Argc, char** Argv) int BpfFd = -1; int PortTrackerFd = -1; + int NetlinkSocketFd = -1; int GuestRelayFd = -1; ArgumentParser parser(Argc, Argv); parser.AddArgument(Integer{BpfFd}, INIT_BPF_FD_ARG); parser.AddArgument(Integer{PortTrackerFd}, INIT_PORT_TRACKER_FD_ARG); + parser.AddArgument(Integer{NetlinkSocketFd}, INIT_NETLINK_FD_ARG); parser.AddArgument(Integer{GuestRelayFd}, INIT_PORT_TRACKER_LOCALHOST_RELAY); try @@ -379,7 +385,7 @@ int RunPortTracker(int Argc, char** Argv) return 1; } - const bool synchronousMode = BpfFd != -1; + const bool synchronousMode = BpfFd != -1 && NetlinkSocketFd != -1; const bool localhostRelay = GuestRelayFd != -1; auto hvSocketChannel = std::make_shared(wil::unique_fd{PortTrackerFd}, "localhost"); @@ -400,14 +406,31 @@ int RunPortTracker(int Argc, char** Argv) if (!synchronousMode) { - std::cerr << "synchronous mode requires --bpf-fd\n"; + std::cerr << "either both or none of --bpf-fd and --netlink-socket can be passed\n"; return 1; } + auto channel = NetlinkChannel::FromFd(NetlinkSocketFd); + auto channelMutex = std::make_shared(); auto seccompDispatcher = std::make_shared(BpfFd); + GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher, channelMutex); + + seccompDispatcher->RegisterHandler( + __NR_bind, [&portTracker](seccomp_notif* notification) { return portTracker.ProcessSecCompNotification(notification); }); + +#ifdef __x86_64__ + seccompDispatcher->RegisterHandler(I386_NR_socketcall, [&portTracker](seccomp_notif* notification) { + return portTracker.ProcessSecCompNotification(notification); + }); +#else + seccompDispatcher->RegisterHandler(ARMV7_NR_bind, [&portTracker](seccomp_notif* notification) { + return portTracker.ProcessSecCompNotification(notification); + }); +#endif + seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher, channelMutex](auto notification) -> int { LX_GNS_TUN_BRIDGE_REQUEST request{}; 
request.Header.MessageType = LxGnsMessageIfStateChangeRequest; @@ -429,7 +452,6 @@ int RunPortTracker(int Argc, char** Argv) return reply.Result; }); - GnsPortTracker portTracker(hvSocketChannel, channelMutex); try { portTracker.Run(); diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp index da3d54812..c5c8f4d9a 100644 --- a/src/linux/init/main.cpp +++ b/src/linux/init/main.cpp @@ -1352,12 +1352,25 @@ Return Value: return; } + wil::unique_fd NetlinkSocket{}; wil::unique_fd BpfFd{}; wil::unique_fd GuestRelayFd{}; switch (Type) { case LxMiniInitPortTrackerTypeMirrored: { + + // + // Create a netlink socket before registering the bpf filter so creation of the socket + // does not trigger the filter. + // + + NetlinkSocket = CreateNetlinkSocket(); + if (!NetlinkSocket) + { + return; + } + BpfFd = RegisterSeccompHook(); if (!BpfFd) { @@ -1385,6 +1398,7 @@ Return Value: UtilCreateChildProcess( "PortTracker", [PortTrackerFd = std::move(PortTrackerFd), + NetlinkSocket = std::move(NetlinkSocket), BpfFd = std::move(BpfFd), GuestRelayFd = std::move(GuestRelayFd)]() { execl( @@ -1394,6 +1408,8 @@ Return Value: std::format("{}", PortTrackerFd.get()).c_str(), INIT_BPF_FD_ARG, std::format("{}", BpfFd.get()).c_str(), + INIT_NETLINK_FD_ARG, + std::format("{}", NetlinkSocket.get()).c_str(), INIT_PORT_TRACKER_LOCALHOST_RELAY, std::format("{}", GuestRelayFd.get()).c_str(), NULL); @@ -3524,7 +3540,7 @@ wil::unique_fd RegisterSeccompHook() Routine Description: - Register a seccomp notification for ioctl(*, TUNSETIFF, *) calls. + Register a seccomp notification for bind() & ioctl(*, SIOCSIFFLAGS, *) calls. 
Arguments: @@ -3547,9 +3563,12 @@ Return Value: // 64bit: // If syscall_arch & __AUDIT_ARCH_64BIT then continue else goto :32bit BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arch), - BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, __AUDIT_ARCH_64BIT, 0, 5), - // If syscall_nr == __NR_ioctl then continue else goto allow: + // For now, notify on all non-native arch + BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, __AUDIT_ARCH_64BIT, 0, 7), + // If syscall_nr == __NR_bind then goto user_notify: else continue BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), + BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_bind, 3, 0), + // if (syscall_nr == __NR_ioctl) then continue else goto allow: BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_ioctl, 0, 3), // if (syscall arg1 == SIOCSIFFLAGS) goto user_notify else goto allow: BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(1)), @@ -3561,10 +3580,34 @@ Return Value: // return SECCOMP_RET_ALLOW; BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), - // 32bit: no ioctl interception needed for 32-bit processes. + // Note: 32bit on x86_64 uses the __NR_socketcall with the first argument + // set to SYS_BIND to make bind system call. 
+#ifdef __x86_64__ + // 32bit: + // If syscall_nr == __NR_socketcall then continue else goto allow: + BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), + BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, I386_NR_socketcall, 0, 3), + // if syscall arg0 == SYS_BIND then goto user_notify: else goto allow: + BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(0)), + BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, SYS_BIND, 0, 1), + // user_notify: + // return SECCOMP_RET_USER_NOTIF; + BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_USER_NOTIF), + // allow: + // return SECCOMP_RET_ALLOW; + BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), +#else + // 32bit: + // If syscall_nr == __NR_bind then goto user_notify: else goto allow: + BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), + BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, ARMV7_NR_bind, 0, 1), + // user_notify: + // return SECCOMP_RET_USER_NOTIF; + BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_USER_NOTIF), // allow: // return SECCOMP_RET_ALLOW; BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), +#endif }; struct sock_fprog Prog = {