diff --git a/CMakeLists.txt b/CMakeLists.txt index 2834728e3..b840aa577 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,8 @@ find_nuget_package(Microsoft.WSL.DeviceHost WSL_DEVICE_HOST /build/native) find_nuget_package(Microsoft.WSL.Kernel KERNEL /build/native) find_nuget_package(Microsoft.WSL.bsdtar BSDTARD /build/native/bin) find_nuget_package(Microsoft.WSL.LinuxSdk LINUXSDK /) +#find_nuget_package(Microsoft.WSL.BPF WSL_BPF /) +set(WSL_BPF_SOURCE_DIR "${CMAKE_SOURCE_DIR}/packages/Microsoft.WSL.BPF.1.0.0") find_nuget_package(Microsoft.WSL.TestDistro TEST_DISTRO /) find_nuget_package(Microsoft.WSLg WSLG /build/native/bin) find_nuget_package(vswhere VSWHERE /tools) diff --git a/packages.config b/packages.config index 22cc87bb7..a71b31702 100644 --- a/packages.config +++ b/packages.config @@ -16,6 +16,7 @@ + diff --git a/src/linux/bpf/CMakeLists.txt b/src/linux/bpf/CMakeLists.txt new file mode 100644 index 000000000..e069def24 --- /dev/null +++ b/src/linux/bpf/CMakeLists.txt @@ -0,0 +1,124 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# +# CMakeLists.txt for building BPF skeleton headers. +# This is intended to be run on a Linux build machine, NOT as part of the main Windows build. +# +# Required parameters: +# -DKERNEL_IMAGE_PATH= Path to the WSL kernel binary (bzImage/Image) with embedded BTF +# -DBPFTOOL_PATH= Path to the bpftool binary +# +# Optional parameters: +# -DCLANG_BPF= Path to clang (default: clang) +# -DOUTPUT_DIR= Output directory for generated headers (default: ${CMAKE_BINARY_DIR}/output) +# +# Usage: +# cmake -S . 
-B build -DKERNEL_IMAGE_PATH=/path/to/bzImage -DBPFTOOL_PATH=/usr/sbin/bpftool +# cmake --build build +# +# Output: +# ${OUTPUT_DIR}/bind_monitor.skel.h - BPF skeleton header (embed in NuGet) +# ${OUTPUT_DIR}/bind_monitor.h - Shared event struct header (embed in NuGet) + +cmake_minimum_required(VERSION 3.20) +project(wsl-bpf LANGUAGES NONE) + +# Validate required parameters +if(NOT DEFINED KERNEL_IMAGE_PATH) + message(FATAL_ERROR "KERNEL_IMAGE_PATH is required. Pass -DKERNEL_IMAGE_PATH=") +endif() + +if(NOT EXISTS "${KERNEL_IMAGE_PATH}") + message(FATAL_ERROR "Kernel image not found: ${KERNEL_IMAGE_PATH}") +endif() + +if(NOT DEFINED BPFTOOL_PATH) + message(FATAL_ERROR "BPFTOOL_PATH is required. Pass -DBPFTOOL_PATH=") +endif() + +if(NOT EXISTS "${BPFTOOL_PATH}") + message(FATAL_ERROR "bpftool not found: ${BPFTOOL_PATH}") +endif() + +# Optional parameters +if(NOT DEFINED CLANG_BPF) + set(CLANG_BPF "clang") +endif() + +if(NOT DEFINED OUTPUT_DIR) + set(OUTPUT_DIR "${CMAKE_BINARY_DIR}/output") +endif() + +file(MAKE_DIRECTORY "${OUTPUT_DIR}") + +set(SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +set(VMLINUX_H "${CMAKE_BINARY_DIR}/vmlinux.h") +set(BPF_OBJ "${CMAKE_BINARY_DIR}/bind_monitor.bpf.o") +set(SKEL_H "${OUTPUT_DIR}/bind_monitor.skel.h") +set(EVENT_H "${OUTPUT_DIR}/bind_monitor.h") + +# Detect target architecture for BPF +execute_process( + COMMAND uname -m + OUTPUT_VARIABLE HOST_ARCH + OUTPUT_STRIP_TRAILING_WHITESPACE) + +if(HOST_ARCH STREQUAL "x86_64") + set(BPF_TARGET_ARCH "x86") +elseif(HOST_ARCH STREQUAL "aarch64") + set(BPF_TARGET_ARCH "arm64") +else() + message(FATAL_ERROR "Unsupported architecture: ${HOST_ARCH}") +endif() + +# Locate libbpf headers (needed for BPF compilation) +find_path(LIBBPF_INCLUDE_DIR "bpf/bpf_helpers.h" + PATHS /usr/include /usr/local/include + REQUIRED) + +# Step 1: Generate vmlinux.h from kernel image +add_custom_command( + OUTPUT "${VMLINUX_H}" + COMMAND "${CMAKE_COMMAND}" -E env "BPFTOOL=${BPFTOOL_PATH}" bash "${SRC_DIR}/generate-vmlinux-header.sh" "${KERNEL_IMAGE_PATH}" "${VMLINUX_H}" # the script reads $BPFTOOL; hand it the validated binary instead of relying on PATH +
DEPENDS "${KERNEL_IMAGE_PATH}" "${SRC_DIR}/generate-vmlinux-header.sh" + COMMENT "Generating vmlinux.h from kernel image" + VERBATIM) + +# Step 2: Compile BPF program +add_custom_command( + OUTPUT "${BPF_OBJ}" + COMMAND "${CLANG_BPF}" -O2 -g -target bpf + -D__TARGET_ARCH_${BPF_TARGET_ARCH} + -I "${CMAKE_BINARY_DIR}" + -I "${LIBBPF_INCLUDE_DIR}" + -I "${SRC_DIR}" + -c "${SRC_DIR}/bind_monitor.bpf.c" + -o "${BPF_OBJ}" + DEPENDS "${SRC_DIR}/bind_monitor.bpf.c" "${SRC_DIR}/bind_monitor.h" "${VMLINUX_H}" + COMMENT "Compiling bind_monitor.bpf.c" + VERBATIM) + +# Step 3: Generate skeleton header +add_custom_command( + OUTPUT "${SKEL_H}" + COMMAND "${BPFTOOL_PATH}" gen skeleton "${BPF_OBJ}" > "${SKEL_H}" + DEPENDS "${BPF_OBJ}" + COMMENT "Generating bind_monitor.skel.h" + VERBATIM) + +# Step 4: Copy shared header to output +add_custom_command( + OUTPUT "${EVENT_H}" + COMMAND ${CMAKE_COMMAND} -E copy "${SRC_DIR}/bind_monitor.h" "${EVENT_H}" + DEPENDS "${SRC_DIR}/bind_monitor.h" + COMMENT "Copying bind_monitor.h to output" + VERBATIM) + +# Main target +add_custom_target(bpf_headers ALL DEPENDS "${SKEL_H}" "${EVENT_H}") + +message(STATUS "BPF build configuration:") +message(STATUS " Kernel image: ${KERNEL_IMAGE_PATH}") +message(STATUS " bpftool: ${BPFTOOL_PATH}") +message(STATUS " clang: ${CLANG_BPF}") +message(STATUS " Target arch: ${BPF_TARGET_ARCH}") +message(STATUS " Output dir: ${OUTPUT_DIR}") diff --git a/src/linux/bpf/bind_monitor.bpf.c b/src/linux/bpf/bind_monitor.bpf.c new file mode 100644 index 000000000..4a2eaa586 --- /dev/null +++ b/src/linux/bpf/bind_monitor.bpf.c @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include "vmlinux.h" +#include +#include +#include +#include +#include "bind_monitor.h" + +struct { + __uint(type, BPF_MAP_TYPE_RINGBUF); + __uint(max_entries, BIND_MONITOR_RINGBUF_SIZE); +} events SEC(".maps"); + +static __always_inline int emit_bind_event(struct sock *sk, int ret) +{ + struct bind_event *e; + + if (!sk || ret != 0) + return 0; + 
+ __u16 family = BPF_CORE_READ(sk, __sk_common.skc_family); + if (family != AF_INET && family != AF_INET6) + return 0; + + __u16 protocol = BPF_CORE_READ(sk, sk_protocol); + if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) + return 0; + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + return 0; + + e->family = family; + e->protocol = protocol; + e->port = BPF_CORE_READ(sk, __sk_common.skc_num); + e->pad = 0; + + if (family == AF_INET) + { + struct inet_sock *inet = (struct inet_sock *)sk; + e->addr4 = BPF_CORE_READ(inet, inet_saddr); + __builtin_memset(e->addr6, 0, sizeof(e->addr6)); + } + else + { + struct ipv6_pinfo *pinet6; + e->addr4 = 0; + pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); + if (pinet6) + BPF_CORE_READ_INTO(&e->addr6, pinet6, saddr.in6_u.u6_addr8); /* dst must be &array: the macro copies sizeof(*dst) bytes, which is 1 for a decayed __u8* */ + else + __builtin_memset(e->addr6, 0, sizeof(e->addr6)); + } + + bpf_ringbuf_submit(e, 0); + return 0; +} + +// fexit/inet_bind: called after inet_bind() returns. +SEC("fexit/inet_bind") +int BPF_PROG(fexit_inet_bind, struct socket *sock, struct sockaddr *uaddr, + int addr_len, int ret) +{ + struct sock *sk = BPF_CORE_READ(sock, sk); + return emit_bind_event(sk, ret); +} + +// fexit/inet6_bind: called after inet6_bind() returns. +SEC("fexit/inet6_bind") +int BPF_PROG(fexit_inet6_bind, struct socket *sock, struct sockaddr *uaddr, + int addr_len, int ret) +{ + struct sock *sk = BPF_CORE_READ(sock, sk); + return emit_bind_event(sk, ret); +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/src/linux/bpf/bind_monitor.h b/src/linux/bpf/bind_monitor.h new file mode 100644 index 000000000..6138e4a48 --- /dev/null +++ b/src/linux/bpf/bind_monitor.h @@ -0,0 +1,35 @@ +// Copyright (C) Microsoft Corporation. All rights reserved.
+ +#pragma once + +#ifndef __u8 +typedef unsigned char __u8; +typedef unsigned short __u16; +typedef unsigned int __u32; +typedef unsigned long long __u64; +#endif + +#ifndef AF_INET +#define AF_INET 2 +#endif +#ifndef AF_INET6 +#define AF_INET6 10 +#endif + +#ifndef IPPROTO_TCP +#define IPPROTO_TCP 6 +#endif +#ifndef IPPROTO_UDP +#define IPPROTO_UDP 17 +#endif + +#define BIND_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB */ + +struct bind_event { + __u32 family; /* AF_INET or AF_INET6 */ + __u32 protocol; /* IPPROTO_TCP or IPPROTO_UDP */ + __u16 port; /* host byte order */ + __u16 pad; + __u32 addr4; /* IPv4 address (network byte order) */ + __u8 addr6[16]; /* IPv6 address */ +}; diff --git a/src/linux/bpf/generate-vmlinux-header.sh b/src/linux/bpf/generate-vmlinux-header.sh new file mode 100644 index 000000000..5c7a73a74 --- /dev/null +++ b/src/linux/bpf/generate-vmlinux-header.sh @@ -0,0 +1,111 @@ +#!/bin/bash +set -e + +if [ $# -lt 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +BZIMAGE="$1" +OUTPUT="$2" +BPFTOOL="${BPFTOOL:-bpftool}" + +TMPDIR=$(mktemp -d) +trap 'rm -rf "$TMPDIR"' EXIT + +VMLINUX="$TMPDIR/vmlinux" + +echo "Extracting vmlinux from $BZIMAGE..." + +# First, try bpftool directly (works if input is already an ELF with BTF) +if "$BPFTOOL" btf dump file "$BZIMAGE" format c > "$OUTPUT" 2>/dev/null; then + LINES=$(wc -l < "$OUTPUT") + if [ "$LINES" -gt 100 ]; then + echo "Done. Generated $OUTPUT ($LINES lines)" + exit 0 + fi +fi + +# Try to find an ELF embedded in the image (e.g., ARM64 Image) +ELF_OFFSET=$(binwalk -y elf "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true +if [ -n "$ELF_OFFSET" ]; then + echo "Found ELF at offset $ELF_OFFSET" + tail -c +$((ELF_OFFSET + 1)) "$BZIMAGE" > "$VMLINUX" + if file "$VMLINUX" | grep -q 'ELF' && "$BPFTOOL" btf dump file "$VMLINUX" format c > "$OUTPUT" 2>/dev/null; then + LINES=$(wc -l < "$OUTPUT") + if [ "$LINES" -gt 100 ]; then + echo "Done. 
Generated $OUTPUT ($LINES lines)" + exit 0 + fi + fi +fi + +# Try to find raw BTF data in the image (ARM64 Image stores BTF as raw data) +BTF_RAW="$TMPDIR/btf.raw" +BTF_OFFSET=$(python3 -c " +import struct, sys +data = open(sys.argv[1], 'rb').read() +magic = b'\x9f\xeb' +idx = 0 +while True: + idx = data.find(magic, idx) + if idx == -1: break + if data[idx+2] == 1: # version 1 + hdr_len = struct.unpack_from(' 1000: + print(f'{idx} {total}') + break + idx += 1 +" "$BZIMAGE" 2>/dev/null) || true + +if [ -n "$BTF_OFFSET" ]; then + BTF_OFF=$(echo "$BTF_OFFSET" | awk '{print $1}') + BTF_SIZE=$(echo "$BTF_OFFSET" | awk '{print $2}') + echo "Found raw BTF at offset $BTF_OFF ($BTF_SIZE bytes)" + tail -c +$((BTF_OFF + 1)) "$BZIMAGE" | head -c "$BTF_SIZE" > "$BTF_RAW" + if "$BPFTOOL" btf dump file "$BTF_RAW" format c > "$OUTPUT" 2>/dev/null; then + LINES=$(wc -l < "$OUTPUT") + if [ "$LINES" -gt 100 ]; then + echo "Done. Generated $OUTPUT ($LINES lines)" + exit 0 + fi + fi +fi + +# Find gzip offset and decompress to get the vmlinux ELF +GZIP_OFFSET=$(binwalk -y gzip "$BZIMAGE" 2>/dev/null | grep -oP '^\d+' | head -1) || true + +if [ -z "$GZIP_OFFSET" ]; then + echo "Error: no gzip or ELF payload found in $BZIMAGE" >&2 + exit 1 +fi + +echo "Found gzip payload at offset $GZIP_OFFSET" +tail -c +$((GZIP_OFFSET + 1)) "$BZIMAGE" | zcat > "$VMLINUX" 2>/dev/null || true + +if ! file "$VMLINUX" | grep -q 'ELF'; then + # The gzip payload might contain another layer; try to find ELF inside + INNER="$TMPDIR/inner" + tail -c +$((GZIP_OFFSET + 1)) "$BZIMAGE" | zcat 2>/dev/null > "$INNER" || true + + # Search for ELF magic in decompressed data + ELF_OFFSET=$(grep -a -b -o -P '\x7fELF' "$INNER" 2>/dev/null | head -1 | cut -d: -f1) || true + if [ -n "$ELF_OFFSET" ]; then + tail -c +$((ELF_OFFSET + 1)) "$INNER" > "$VMLINUX" + fi +fi + +if ! 
file "$VMLINUX" | grep -q 'ELF'; then + echo "Error: could not extract a valid ELF vmlinux from $BZIMAGE" >&2 + exit 1 +fi + +echo "Extracted vmlinux: $(file "$VMLINUX")" +echo "Generating vmlinux.h with bpftool..." +"$BPFTOOL" btf dump file "$VMLINUX" format c > "$OUTPUT" + +LINES=$(wc -l < "$OUTPUT") +echo "Done. Generated $OUTPUT ($LINES lines)" diff --git a/src/linux/bpf/listen_monitor.bpf.c b/src/linux/bpf/listen_monitor.bpf.c new file mode 100644 index 000000000..b89bada00 --- /dev/null +++ b/src/linux/bpf/listen_monitor.bpf.c @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include "vmlinux.h" +#include +#include +#include +#include +#include "listen_monitor.h" + +struct { + __uint(type, BPF_MAP_TYPE_RINGBUF); + __uint(max_entries, LISTEN_MONITOR_RINGBUF_SIZE); +} events SEC(".maps"); + +// Store listen info at listen() time so release can emit correct values +// even if the kernel has already cleared the socket fields. +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __type(key, __u64); // sock pointer as key + __type(value, struct listen_event); // snapshot from listen time + __uint(max_entries, 65536); +} listening_sockets SEC(".maps"); + +static __always_inline int read_sock_info(struct sock *sk, struct listen_event *info) +{ + info->family = BPF_CORE_READ(sk, __sk_common.skc_family); + if (info->family != AF_INET && info->family != AF_INET6) + return -1; + + info->port = BPF_CORE_READ(sk, __sk_common.skc_num); + + if (info->family == AF_INET) + { + struct inet_sock *inet = (struct inet_sock *)sk; + info->addr4 = BPF_CORE_READ(inet, inet_saddr); + __builtin_memset(info->addr6, 0, sizeof(info->addr6)); + } + else + { + struct ipv6_pinfo *pinet6; + info->addr4 = 0; + pinet6 = BPF_CORE_READ((struct inet_sock *)sk, pinet6); + if (pinet6) + BPF_CORE_READ_INTO(info->addr6, pinet6, saddr.in6_u.u6_addr8); + else + __builtin_memset(info->addr6, 0, sizeof(info->addr6)); + } + + return 0; +} + +// fexit/inet_listen: called after inet_listen() returns. 
+// inet_listen handles both IPv4 and IPv6. +// int inet_listen(struct socket *sock, int backlog) +SEC("fexit/inet_listen") +int BPF_PROG(fexit_inet_listen, struct socket *sock, int backlog, int ret) +{ + struct listen_event *e; + struct listen_event info = {}; + __u64 sk_key; + struct sock *sk; + + if (ret != 0) + return 0; + + sk = BPF_CORE_READ(sock, sk); + if (!sk) + return 0; + + sk_key = (__u64)sk; + + if (read_sock_info(sk, &info) < 0) + return 0; + + info.is_listen = 1; + + bpf_map_update_elem(&listening_sockets, &sk_key, &info, BPF_ANY); + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + return 0; + + *e = info; + bpf_ringbuf_submit(e, 0); + return 0; +} + +// fentry/inet_release: called when a socket is being closed. +// void inet_release(struct socket *sock) +SEC("fentry/inet_release") +int BPF_PROG(fentry_inet_release, struct socket *sock) +{ + struct listen_event *e; + struct listen_event *info; + struct sock *sk; + __u64 sk_key; + + sk = BPF_CORE_READ(sock, sk); + if (!sk) + return 0; + + sk_key = (__u64)sk; + + info = bpf_map_lookup_elem(&listening_sockets, &sk_key); + if (!info) + return 0; + + e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); + if (!e) + { + bpf_map_delete_elem(&listening_sockets, &sk_key); + return 0; + } + + *e = *info; + e->is_listen = 0; + + bpf_map_delete_elem(&listening_sockets, &sk_key); + bpf_ringbuf_submit(e, 0); + return 0; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/src/linux/bpf/listen_monitor.h b/src/linux/bpf/listen_monitor.h new file mode 100644 index 000000000..58e684dcd --- /dev/null +++ b/src/linux/bpf/listen_monitor.h @@ -0,0 +1,28 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. 
+ +#pragma once + +#ifndef __u8 +typedef unsigned char __u8; +typedef unsigned short __u16; +typedef unsigned int __u32; +typedef unsigned long long __u64; +#endif + +#ifndef AF_INET +#define AF_INET 2 +#endif +#ifndef AF_INET6 +#define AF_INET6 10 +#endif + +#define LISTEN_MONITOR_RINGBUF_SIZE (1 << 16) /* 64 KB */ + +struct listen_event { + __u32 family; /* AF_INET or AF_INET6 */ + __u16 port; /* host byte order */ + __u8 is_listen; /* 1 = started listening, 0 = stopped listening */ + __u8 pad; + __u32 addr4; /* IPv4 address (network byte order) */ + __u8 addr6[16]; /* IPv6 address */ +}; diff --git a/src/linux/init/CMakeLists.txt b/src/linux/init/CMakeLists.txt index d5dee8005..ef19f6812 100644 --- a/src/linux/init/CMakeLists.txt +++ b/src/linux/init/CMakeLists.txt @@ -46,7 +46,14 @@ set(HEADERS wslpath.h) set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${CMAKE_CURRENT_LIST_DIR}/../netlinkutil") +set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/bpf/${TARGET_PLATFORM}/include") +set(LINUX_CXXFLAGS ${LINUX_CXXFLAGS} -I "${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/include") set(INIT_LIBRARIES ${COMMON_LINUX_LINK_LIBRARIES} netlinkutil plan9 mountutil configfile) +set(INIT_EXTRA_LIBS + ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libbpf.a + ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libelf.a + ${WSL_BPF_SOURCE_DIR}/libbpf/${TARGET_PLATFORM}/lib/libz.a) +set(LINUX_LDFLAGS ${LINUX_LDFLAGS} ${INIT_EXTRA_LIBS}) add_linux_executable(init "${SOURCES}" "${HEADERS};${COMMON_LINUX_HEADERS}" "${INIT_LIBRARIES}") add_dependencies(init localization) diff --git a/src/linux/init/GnsPortTracker.cpp b/src/linux/init/GnsPortTracker.cpp index 669582a14..dfa0ca145 100644 --- a/src/linux/init/GnsPortTracker.cpp +++ b/src/linux/init/GnsPortTracker.cpp @@ -3,16 +3,19 @@ #include #include #include -#include /* Definition of AUDIT_* constants */ +#include #include #include #include #include #include -#include "common.h" // Needs to be included before sal.h before 
of __reserved macro +#include +#include "common.h" #include "NetlinkTransactionError.h" #include "GnsPortTracker.h" #include "lxinitshared.h" +#include "bind_monitor.h" +#include "bind_monitor.skel.h" constexpr size_t c_bind_timeout_seconds = 60; constexpr auto c_sock_diag_refresh_delay = std::chrono::milliseconds(500); @@ -20,8 +23,14 @@ constexpr auto c_sock_diag_poll_timeout = std::chrono::milliseconds(10); constexpr auto c_bpf_poll_timeout = std::chrono::milliseconds(500); GnsPortTracker::GnsPortTracker( - std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher) : - m_hvSocketChannel(std::move(hvSocketChannel)), m_channel(std::move(netlinkChannel)), m_seccompDispatcher(seccompDispatcher) + std::shared_ptr hvSocketChannel, + NetlinkChannel&& netlinkChannel, + std::shared_ptr seccompDispatcher, + std::shared_ptr channelMutex) : + m_hvSocketChannel(std::move(hvSocketChannel)), + m_channelMutex(std::move(channelMutex)), + m_channel(std::move(netlinkChannel)), + m_seccompDispatcher(seccompDispatcher) { m_networkNamespace = std::filesystem::read_symlink("/proc/self/ns/net").string(); } @@ -65,15 +74,116 @@ int GnsPortTracker::ProcessSecCompNotification(seccomp_notif* notification) return m_reply.get(); } -void GnsPortTracker::Run() +void GnsPortTracker::OnBindEvent(void* Data) { - // This method consumes seccomp notifications and allows / disallows port allocations - // depending on wsl core's response. 
- // After dealing with a notification it also looks at the bound ports list to check - // for port deallocation + const auto* event = static_cast(Data); + + in6_addr address = {}; + if (event->family == AF_INET) + { + address.s6_addr32[0] = event->addr4; + } + else if (event->family == AF_INET6) + { + memcpy(address.s6_addr32, event->addr6, sizeof(address.s6_addr32)); + } + else + { + return; + } + + PortAllocation allocation(event->port, event->family, event->protocol, address); + + // Check if this port is already tracked by the seccomp path (non-zero bind). + // If it is, this is not a port-0 bind — skip it. + { + std::lock_guard lock(m_portsMutex); + if (m_allocatedPorts.contains(allocation)) + { + return; + } + } + + // This is a port-0 bind that seccomp let through. Report it to the host. + const auto result = HandleRequest(allocation); + if (result == 0) + { + TrackPort(std::move(allocation)); + } + else + { + GNS_LOG_ERROR( + "Failed to register port-0 bind: family ({}) port ({}) protocol ({}), error {}", + allocation.Family, + allocation.Port, + allocation.Protocol, + result); + } +} +void GnsPortTracker::Run() +{ std::thread{std::bind(&GnsPortTracker::RunPortRefresh, this)}.detach(); - std::thread{std::bind(&GnsPortTracker::RunDeferredResolve, this)}.detach(); + + // Start eBPF bind monitor on a background thread to handle port-0 bind resolution. + std::thread{[this]() { + try + { + auto* skel = bind_monitor_bpf__open_and_load(); + if (!skel) + { + LOG_ERROR("Failed to open/load bind monitor BPF program, {}", errno); + return; + } + + auto destroySkel = wil::scope_exit([&] { bind_monitor_bpf__destroy(skel); }); + + if (bind_monitor_bpf__attach(skel) != 0) + { + LOG_ERROR("Failed to attach bind monitor BPF program, {}", errno); + return; + } + + auto onEvent = [](void* ctx, void* data, size_t dataSz) noexcept -> int { + try + { + static_cast(ctx)->OnBindEvent(data); + } + catch (...) 
+ { + LOG_CAUGHT_EXCEPTION_MSG("Error processing bind monitor event"); + } + return 0; + }; + + auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), onEvent, this, nullptr); + if (!rb) + { + LOG_ERROR("Failed to create bind monitor ring buffer, {}", errno); + return; + } + + auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); + + GNS_LOG_INFO("BPF bind monitor for port-0 resolution attached and running"); + + for (;;) + { + int err = ring_buffer__poll(rb, -1); + if (err == -EINTR) + { + continue; + } + + if (err < 0) + { + LOG_ERROR("bind monitor ring_buffer__poll failed, {}", err); + return; + } + } + } + CATCH_LOG() + }}.detach(); auto future = std::make_optional(m_allocatedPortsRefresh.get_future()); std::optional refreshResult; @@ -117,19 +227,8 @@ void GnsPortTracker::Run() GNS_LOG_ERROR("Failed to complete bind request, {}", e.what()); } - if (bindCall->PortZeroBind.has_value()) - { - try - { - std::lock_guard lock(m_deferredMutex); - m_deferredQueue.push_back(std::move(bindCall->PortZeroBind.value())); - m_deferredCv.notify_one(); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to queue port-0 bind for deferred resolution, {}", e.what()); - } - } + // Port-0 binds are now handled by the eBPF callback (OnBindEvent), + // so no deferred resolution is needed here. } // If bindCall is empty, then the read() timed out. Look for any closed port @@ -139,46 +238,21 @@ void GnsPortTracker::Run() future.reset(); m_allocatedPortsRefresh = {}; - // If this loop's iteration had a bind call, it's possible that RefreshAllocatedPort - // was called before the bind called was processed. Make sure that the port list - // is up to date (If this is called, the next block will schedule another refresh) if (!bindCall.has_value()) { + std::lock_guard lock(m_portsMutex); OnRefreshAllocatedPorts(refreshResult->Ports, refreshResult->Timestamp); } } - // Process any port-0 binds that the background thread has resolved. 
- std::deque resolved; - { - std::lock_guard lock(m_resolvedMutex); - resolved.swap(m_resolvedQueue); - } - for (auto& allocation : resolved) - { - const auto result = HandleRequest(allocation); - if (result == 0) - { - TrackPort(std::move(allocation)); - } - else - { - GNS_LOG_ERROR( - "Failed to register resolved port-0 bind: family ({}) port ({}) protocol ({}), error {}", - allocation.Family, - allocation.Port, - allocation.Protocol, - result); - } - } - // Only look at bound ports if there's something to deallocate to avoid wasting cycles if (refreshResult.has_value()) { + std::lock_guard lock(m_portsMutex); if (!m_allocatedPorts.empty()) { future = m_allocatedPortsRefresh.get_future(); - refreshResult->Resume(); // This will resume the sock_diag thread + refreshResult->Resume(); refreshResult.reset(); } } @@ -244,7 +318,9 @@ std::set GnsPortTracker::ListAllocatedPorts() void GnsPortTracker::OnRefreshAllocatedPorts(const std::set& Ports, time_t Timestamp) { - // Because there's no way to get notified when the bind() call actually completes, it' possible + // m_portsMutex must be held by caller. + // + // Because there's no way to get notified when the bind() call actually completes, it's possible // that this method is called before the bind() completion and so the port allocation may not be visible yet. 
// To avoid deallocating ports that simply haven't been done allocating yet, m_allocatedPorts stores a timeout // that prevents deallocating the port unless: @@ -296,6 +372,7 @@ int GnsPortTracker::RequestPort(const PortAllocation& Port, bool allocate) static_assert(sizeof(request.Address32) == sizeof(Port.Address.s6_addr32)); memcpy(request.Address32, Port.Address.s6_addr32, sizeof(request.Address32)); + std::lock_guard lock(*m_channelMutex); const auto& response = m_hvSocketChannel->Transaction(request); return response.Result; @@ -303,14 +380,13 @@ int GnsPortTracker::RequestPort(const PortAllocation& Port, bool allocate) int GnsPortTracker::HandleRequest(const PortAllocation& Port) { - // If the port is already allocated, let the call go through and the kernel will - // decide if bind() should succeed or not - // Note: Returning 0 will also cause the port's timeout to be updated - - if (m_allocatedPorts.contains(Port)) { - GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol); - return 0; + std::lock_guard lock(m_portsMutex); + if (m_allocatedPorts.contains(Port)) + { + GNS_LOG_INFO("Request for a port that's already reserved (family {}, port {}, protocol {})", Port.Family, Port.Port, Port.Protocol); + return 0; + } } // Ask the host for this port otherwise @@ -322,7 +398,6 @@ int GnsPortTracker::HandleRequest(const PortAllocation& Port) std::optional GnsPortTracker::ReadNextRequest() { - // Read the call information auto request_value = m_request.try_get(c_bpf_poll_timeout); if (!request_value.has_value()) { @@ -331,11 +406,6 @@ std::optional GnsPortTracker::ReadNextRequest() auto callInfo = request_value.value(); - // This logic needs to be defensive because the calling process is blocked until - // CompleteRequest() is called, so if the call information can't be processed because - // the caller has done something wrong (bad pointer, fd, or protocol), just let it go through - // 
and the kernel will fail it - try { return GetCallInfo(callInfo.id, callInfo.pid, callInfo.data.arch, callInfo.data.nr, gsl::make_span(callInfo.data.args)); @@ -343,7 +413,7 @@ std::optional GnsPortTracker::ReadNextRequest() catch (const std::exception& e) { GNS_LOG_ERROR("Failed to read bind() call info with ID {} for pid {}, {}", callInfo.id, callInfo.pid, e.what()); - return {{{}, {}, callInfo.id}}; + return {{{}, false, callInfo.id}}; } } @@ -353,14 +423,14 @@ std::optional GnsPortTracker::GetCallInfo( auto ParseSocket = [&](int Socket, size_t AddressPtr, size_t AddressLength) -> std::optional { if (AddressLength < sizeof(sockaddr)) { - return {{{}, {}, CallId}}; // Invalid sockaddr. Let it go through. + return {{{}, false, CallId}}; } auto networkNamespace = std::filesystem::read_symlink(std::format("/proc/{}/ns/net", Pid)).string(); if (networkNamespace != m_networkNamespace) { GNS_LOG_INFO("Skipping bind() call for pid {} in network namespace {}", Pid, networkNamespace.c_str()); - return {{{}, {}, CallId}}; // Different network namespace. Let it go through. + return {{{}, false, CallId}}; } auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, AddressPtr, AddressLength); @@ -374,40 +444,19 @@ std::optional GnsPortTracker::GetCallInfo( if ((address.sa_family != AF_INET && address.sa_family != AF_INET6) || (address.sa_family == AF_INET6 && AddressLength < sizeof(sockaddr_in6))) { - return {{{}, {}, CallId}}; // This is a non IP call, or invalid sockaddr_in6. Let it go through + return {{{}, false, CallId}}; } - // Read the port. The port *happens* to be in the same spot in memory for both sockaddr_in - // and sockaddr_in6. To avoid a second memory read, we take advantage of this fact to fetch - // the port from the currently read memory, regardless of the address family. 
static_assert(sizeof(sockaddr_in) <= sizeof(sockaddr)); const auto* inAddr = reinterpret_cast(&address); in_port_t port = ntohs(inAddr->sin_port); if (port == 0) { - // Port 0 means the kernel will assign an ephemeral port. We can't know - // the port until after the bind() completes, so duplicate the socket fd - // now (while the process is still stopped by seccomp) and defer the - // getsockname() lookup to after CompleteRequest() unblocks it. - try - { - const int protocol = GetSocketProtocol(Pid, Socket); - auto dupFd = DuplicateSocketFd(Pid, Socket); - if (!dupFd) - { - return {{{}, {}, CallId}}; - } - if (!m_seccompDispatcher->ValidateCookie(CallId)) - { - return {{{}, {}, CallId}}; - } - return {{{}, DeferredPortLookup{Pid, std::move(dupFd), protocol}, CallId}}; - } - catch (const std::exception&) - { - return {{{}, {}, CallId}}; // Can't determine protocol, just let it through - } + // Port 0 means the kernel will assign an ephemeral port. + // The eBPF fexit handler (OnBindEvent) will resolve the actual port + // after bind() completes. Just let the call through. + return {{{}, true, CallId}}; } in6_addr storedAddress = {}; @@ -422,35 +471,26 @@ std::optional GnsPortTracker::GetCallInfo( memcpy(storedAddress.s6_addr32, inAddr6->sin6_addr.s6_addr32, sizeof(storedAddress.s6_addr32)); } - // It's possible that the calling process lied and passed a sockaddr that - // doesn't match the underlying socket family or a bad fd. 
If that's the case, - // then GetSocketProtocol() will throw const int protocol = GetSocketProtocol(Pid, Socket); - // As GetSocketProtocol interacts with /proc/, to avoid TOCTOU races we need to - // verify that call is still valid (call id is the same thing as cookie) if (!m_seccompDispatcher->ValidateCookie(CallId)) { throw RuntimeErrorWithSourceLocation(std::format("Invalid call id {}", CallId)); } - return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, {}, CallId}}; + return {{{PortAllocation(port, address.sa_family, protocol, storedAddress)}, false, CallId}}; }; #ifdef __x86_64__ if (Arch & __AUDIT_ARCH_64BIT) { return ParseSocket(Arguments[0], Arguments[1], Arguments[2]); } - // Note: 32bit on x86_64 uses the __NR_socketcall with the first argument - // set to SYS_BIND to make bind system call and the second argument is - // a pointer to a block of memory containing the original arguments. else { if (Arguments[0] != SYS_BIND) { - return {{{}, {}, CallId}}; // Not a bind call, just let the call go through + return {{{}, false, CallId}}; } - // Grab the first 3 parameters auto processMemory = m_seccompDispatcher->ReadProcessMemory(CallId, Pid, Arguments[1], sizeof(uint32_t) * 3); if (!processMemory.has_value()) { @@ -506,32 +546,12 @@ int GnsPortTracker::GetSocketProtocol(int pid, int fd) throw RuntimeErrorWithSourceLocation(std::format("Unexpected IP socket protocol: {}", protocol)); } -wil::unique_fd GnsPortTracker::DuplicateSocketFd(pid_t Pid, int SocketFd) -{ - // Duplicate the socket fd from the target process into our address space. - // We cannot use open("/proc/pid/fd/N") for sockets because the symlink target - // (socket:[inode]) is not a valid filesystem path. Use pidfd_getfd() instead. 
- wil::unique_fd pidFd(static_cast(syscall(SYS_pidfd_open, Pid, 0u))); - if (!pidFd) - { - GNS_LOG_INFO("Port-0 bind: pidfd_open failed for pid {} (errno {})", Pid, errno); - return {}; - } - - wil::unique_fd dupFd(static_cast(syscall(SYS_pidfd_getfd, pidFd.get(), SocketFd, 0u))); - if (!dupFd) - { - GNS_LOG_INFO("Port-0 bind: pidfd_getfd failed for pid {} fd {} (errno {})", Pid, SocketFd, errno); - } - - return dupFd; -} - void GnsPortTracker::TrackPort(PortAllocation allocation) try { // Use insert_or_assign so the deallocation timeout is refreshed if the same // port key is already present (emplace would silently keep the old entry). + std::lock_guard lock(m_portsMutex); m_allocatedPorts.insert_or_assign(std::move(allocation), std::make_optional(time(nullptr) + c_bind_timeout_seconds)); } catch (const std::exception& e) @@ -539,104 +559,6 @@ catch (const std::exception& e) GNS_LOG_ERROR("Failed to track port allocation, {}", e.what()); } -void GnsPortTracker::RunDeferredResolve() -{ - UtilSetThreadName("GnsPortZero"); - - for (;;) - { - DeferredPortLookup lookup{0, {}, 0}; - { - std::unique_lock lock(m_deferredMutex); - m_deferredCv.wait(lock, [&] { return !m_deferredQueue.empty(); }); - lookup = std::move(m_deferredQueue.front()); - m_deferredQueue.pop_front(); - } - - const auto pid = lookup.Pid; - try - { - ResolvePortZeroBind(std::move(lookup)); - } - catch (const std::exception& e) - { - GNS_LOG_ERROR("Failed to resolve port-0 bind for pid {}, {}", pid, e.what()); - } - } -} - -void GnsPortTracker::ResolvePortZeroBind(DeferredPortLookup lookup) -{ - // The socket fd was already duplicated (via pidfd_getfd) while the target process - // was stopped by seccomp, so it remains valid even if the process has closed or - // reused the original fd number. - - // The bind() syscall is being completed asynchronously on the seccomp dispatcher - // thread after CompleteRequest() unblocks it. Poll getsockname() briefly until - // the kernel assigns a port. 
- constexpr int maxRetries = 25; - constexpr auto retryDelay = std::chrono::milliseconds(100); - - in_port_t port = 0; - in6_addr address = {}; - int resolvedFamily = 0; - - for (int attempt = 0; attempt < maxRetries; ++attempt) - { - if (attempt > 0) - { - std::this_thread::sleep_for(retryDelay); - } - - sockaddr_storage storage{}; - socklen_t addrLen = sizeof(storage); - if (getsockname(lookup.DuplicatedSocketFd.get(), reinterpret_cast(&storage), &addrLen) != 0) - { - GNS_LOG_ERROR("Port-0 bind: getsockname failed for pid {} (errno {})", lookup.Pid, errno); - return; - } - - resolvedFamily = static_cast(storage.ss_family); - - if (storage.ss_family == AF_INET) - { - const auto* sin = reinterpret_cast(&storage); - port = ntohs(sin->sin_port); - address.s6_addr32[0] = sin->sin_addr.s_addr; - } - else if (storage.ss_family == AF_INET6) - { - const auto* sin6 = reinterpret_cast(&storage); - port = ntohs(sin6->sin6_port); - memcpy(address.s6_addr32, sin6->sin6_addr.s6_addr32, sizeof(address.s6_addr32)); - } - else - { - GNS_LOG_ERROR("Port-0 bind: unexpected address family ({}) for pid {}", resolvedFamily, lookup.Pid); - return; - } - - if (port != 0) - { - break; - } - } - - if (port == 0) - { - GNS_LOG_ERROR("Port-0 bind: kernel did not assign a port for pid {} after retries", lookup.Pid); - return; - } - - PortAllocation allocation(port, resolvedFamily, lookup.Protocol, address); - GNS_LOG_INFO( - "Port-0 bind resolved: family ({}) port ({}) protocol ({}) for pid {}", resolvedFamily, port, lookup.Protocol, lookup.Pid); - { - std::lock_guard lock(m_resolvedMutex); - m_resolvedQueue.push_back(std::move(allocation)); - } -} - std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& entry) { return out << "Port=" << entry.Port << ", Family=" << entry.Family << ", Protocol=" << entry.Protocol; diff --git a/src/linux/init/GnsPortTracker.h b/src/linux/init/GnsPortTracker.h index a3f2945a1..2dfa2a8dd 100644 --- a/src/linux/init/GnsPortTracker.h +++ 
b/src/linux/init/GnsPortTracker.h @@ -21,17 +21,22 @@ class GnsPortTracker { public: - GnsPortTracker(std::shared_ptr hvSocketChannel, NetlinkChannel&& netlinkChannel, std::shared_ptr seccompDispatcher); + GnsPortTracker( + std::shared_ptr hvSocketChannel, + NetlinkChannel&& netlinkChannel, + std::shared_ptr seccompDispatcher, + std::shared_ptr channelMutex); - GnsPortTracker(const GnsPortTracker&) = delete; - GnsPortTracker(GnsPortTracker&&) = delete; - GnsPortTracker& operator=(const GnsPortTracker&) = delete; - GnsPortTracker& operator=(GnsPortTracker&&) = delete; + NON_COPYABLE(GnsPortTracker); + NON_MOVABLE(GnsPortTracker); void Run(); int ProcessSecCompNotification(seccomp_notif* notification); + // Called from the eBPF ring buffer callback thread to handle port-0 bind resolution. + void OnBindEvent(void* data); + struct PortAllocation { in6_addr Address = {}; @@ -94,27 +99,10 @@ class GnsPortTracker } }; - struct DeferredPortLookup - { - pid_t Pid; - wil::unique_fd DuplicatedSocketFd; // Duplicated via pidfd_getfd while process was stopped - int Protocol; - - DeferredPortLookup(pid_t Pid, wil::unique_fd DuplicatedSocketFd, int Protocol) : - Pid(Pid), DuplicatedSocketFd(std::move(DuplicatedSocketFd)), Protocol(Protocol) - { - } - - DeferredPortLookup(DeferredPortLookup&&) = default; - DeferredPortLookup& operator=(DeferredPortLookup&&) = default; - DeferredPortLookup(const DeferredPortLookup&) = delete; - DeferredPortLookup& operator=(const DeferredPortLookup&) = delete; - }; - struct BindCall { std::optional Request; - std::optional PortZeroBind; + bool IsPortZeroBind = false; std::uint64_t CallId; }; @@ -144,16 +132,14 @@ class GnsPortTracker static int GetSocketProtocol(int Pid, int Fd); - static wil::unique_fd DuplicateSocketFd(pid_t Pid, int SocketFd); - - void ResolvePortZeroBind(DeferredPortLookup lookup); - - void RunDeferredResolve(); - void TrackPort(PortAllocation allocation); + // Protects m_allocatedPorts for concurrent access from the main loop 
and the eBPF callback thread. + std::mutex m_portsMutex; std::map> m_allocatedPorts; + std::shared_ptr m_hvSocketChannel; + std::shared_ptr m_channelMutex; NetlinkChannel m_channel; std::promise m_allocatedPortsRefresh; @@ -163,15 +149,6 @@ class GnsPortTracker std::shared_ptr m_seccompDispatcher; std::string m_networkNamespace; - - std::mutex m_deferredMutex; - std::condition_variable m_deferredCv; - std::deque m_deferredQueue; - - // Resolved port-0 allocations posted by the background RunDeferredResolve thread - // for the main Run() loop to process (keeps SocketChannel access single-threaded). - std::mutex m_resolvedMutex; - std::deque m_resolvedQueue; }; std::ostream& operator<<(std::ostream& out, const GnsPortTracker::PortAllocation& portAllocation); diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index f1af32182..3d7e3c3db 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -1,6 +1,7 @@ // Copyright (C) Microsoft Corporation. All rights reserved. 
#include "common.h" #include +#include #include #include #include @@ -12,11 +13,12 @@ #include #include #include -#include -#include #include +#include #include +#include + #include "util.h" #include "SocketChannel.h" #include "GnsPortTracker.h" @@ -25,8 +27,8 @@ #include "CommandLine.h" #include "NetlinkChannel.h" #include "NetlinkTransactionError.h" - -#define TCP_LISTEN 10 +#include "listen_monitor.h" +#include "listen_monitor.skel.h" namespace { @@ -149,64 +151,6 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket) return; } -std::vector QueryListeningSockets(NetlinkChannel& channel) -{ - std::vector sockets{}; - try - { - inet_diag_req_v2 message{}; - message.sdiag_protocol = IPPROTO_TCP; - message.idiag_states = (1 << TCP_LISTEN); - - auto onMessage = [&](const NetlinkResponse& response) { - for (const auto& e : response.Messages(SOCK_DIAG_BY_FAMILY)) - { - const auto* payload = e.Payload(); - sockaddr_storage sock{}; - - if (payload->idiag_family == AF_INET) - { - auto* ipv4 = reinterpret_cast(&sock); - ipv4->sin_family = AF_INET; - ipv4->sin_addr.s_addr = payload->id.idiag_src[0]; - ipv4->sin_port = payload->id.idiag_sport; - } - else if (payload->idiag_family == AF_INET6) - { - auto* ipv6 = reinterpret_cast(&sock); - ipv6->sin6_family = AF_INET6; - static_assert(sizeof(ipv6->sin6_addr.s6_addr32) == sizeof(payload->id.idiag_src)); - memcpy(ipv6->sin6_addr.s6_addr32, payload->id.idiag_src, sizeof(ipv6->sin6_addr.s6_addr32)); - ipv6->sin6_port = payload->id.idiag_sport; - } - - sockets.emplace_back(sock); - } - }; - - // Query IPv4 listening sockets. - { - message.sdiag_family = AF_INET; - auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - - // Query IPv6 listening sockets. 
- { - message.sdiag_family = AF_INET6; - auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP); - transaction.Execute(onMessage); - } - } - catch (const NetlinkTransactionError& e) - { - // Log but don't fail - network state might be temporarily unavailable - LOG_ERROR("Failed to query listening sockets via sock_diag: {}", e.what()); - } - - return sockets; -} - int SendRelayListenerSocket(wsl::shared::SocketChannel& channel, int hvSocketPort) try { @@ -263,87 +207,91 @@ try } CATCH_RETURN_ERRNO(); -bool IsSameSockAddr(const sockaddr_storage& left, const sockaddr_storage& right) +extern "C" int OnListenMonitorEvent(void* ctx, void* data, size_t dataSz) noexcept +try { - if (left.ss_family != right.ss_family) + auto* channel = static_cast(ctx); + const auto* event = static_cast(data); + + sockaddr_storage sock{}; + if (event->family == AF_INET) + { + auto* ipv4 = reinterpret_cast(&sock); + ipv4->sin_family = AF_INET; + ipv4->sin_addr.s_addr = event->addr4; + ipv4->sin_port = htons(event->port); + } + else if (event->family == AF_INET6) { - return false; + auto* ipv6 = reinterpret_cast(&sock); + ipv6->sin6_family = AF_INET6; + memcpy(ipv6->sin6_addr.s6_addr, event->addr6, sizeof(ipv6->sin6_addr.s6_addr)); + ipv6->sin6_port = htons(event->port); + } + else + { + return 0; } - if (left.ss_family == AF_INET) + if (event->is_listen) { - auto leftIpv4 = reinterpret_cast(&left); - auto rightIpv4 = reinterpret_cast(&right); - return (leftIpv4->sin_addr.s_addr == rightIpv4->sin_addr.s_addr && leftIpv4->sin_port == rightIpv4->sin_port); + StartHostListener(*channel, sock); } - else if (left.ss_family == AF_INET6) + else { - auto leftIpv6 = reinterpret_cast(&left); - auto rightIpv6 = reinterpret_cast(&right); - return (leftIpv6->sin6_port == rightIpv6->sin6_port && memcmp(&leftIpv6->sin6_addr, &rightIpv6->sin6_addr, sizeof(in6_addr)) == 0); + StopHostListener(*channel, sock); } - FATAL_ERROR("Unrecognized socket family {}", left.ss_family); - 
return false; + return 0; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION_MSG("Error processing listen monitor event"); + return 0; } -// Monitor listening TCP sockets using sock_diag netlink interface. int MonitorListeningSockets(wsl::shared::SocketChannel& channel) { - NetlinkChannel netlinkChannel(SOCK_RAW, NETLINK_SOCK_DIAG); - std::vector relays{}; - int result = 0; - - for (;;) + auto* skel = listen_monitor_bpf__open_and_load(); + if (!skel) { - auto sockets = QueryListeningSockets(netlinkChannel); + LOG_ERROR("Failed to open/load listen monitor BPF program, {}", errno); + return -1; + } - // Stop any relays that no longer match listening ports. - std::erase_if(relays, [&](const auto& entry) { - auto found = - std::find_if(sockets.begin(), sockets.end(), [&](const auto& socket) { return IsSameSockAddr(entry, socket); }); + auto destroySkel = wil::scope_exit([&] { listen_monitor_bpf__destroy(skel); }); - bool remove = (found == sockets.end()); - if (remove) - { - if (StopHostListener(channel, entry) < 0) - { - result = -1; - } - } + if (listen_monitor_bpf__attach(skel) != 0) + { + LOG_ERROR("Failed to attach listen monitor BPF program, {}", errno); + return -1; + } - return remove; - }); + auto* rb = ring_buffer__new(bpf_map__fd(skel->maps.events), OnListenMonitorEvent, &channel, nullptr); + if (!rb) + { + LOG_ERROR("Failed to create ring buffer, {}", errno); + return -1; + } - // Create relays for any new ports. - std::for_each(sockets.begin(), sockets.end(), [&](const auto& socket) { - auto found = - std::find_if(relays.begin(), relays.end(), [&](const auto& entry) { return IsSameSockAddr(entry, socket); }); + auto destroyRb = wil::scope_exit([&] { ring_buffer__free(rb); }); - if (found == relays.end()) - { - if (StartHostListener(channel, socket) < 0) - { - result = -1; - } - else - { - relays.push_back(socket); - } - } - }); + GNS_LOG_INFO("BPF listen monitor attached and running"); - // Ensure all start / stop operations were successful. 
- if (result < 0) + for (;;) + { + int err = ring_buffer__poll(rb, -1); + if (err == -EINTR) { - break; + continue; } - // Sleep before scanning again. - std::this_thread::sleep_for(std::chrono::seconds(1)); + if (err < 0) + { + LOG_ERROR("ring_buffer__poll failed, {}", err); + return -1; + } } - - return result; } } // namespace @@ -464,9 +412,11 @@ int RunPortTracker(int Argc, char** Argv) auto channel = NetlinkChannel::FromFd(NetlinkSocketFd); + auto channelMutex = std::make_shared(); + auto seccompDispatcher = std::make_shared(BpfFd); - GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher); + GnsPortTracker portTracker(hvSocketChannel, std::move(channel), seccompDispatcher, channelMutex); seccompDispatcher->RegisterHandler( __NR_bind, [&portTracker](seccomp_notif* notification) { return portTracker.ProcessSecCompNotification(notification); }); @@ -481,7 +431,7 @@ int RunPortTracker(int Argc, char** Argv) }); #endif - seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher](auto notification) -> int { + seccompDispatcher->RegisterHandler(__NR_ioctl, [hvSocketChannel, seccompDispatcher, channelMutex](auto notification) -> int { LX_GNS_TUN_BRIDGE_REQUEST request{}; request.Header.MessageType = LxGnsMessageIfStateChangeRequest; request.Header.MessageSize = sizeof(request); @@ -495,6 +445,8 @@ int RunPortTracker(int Argc, char** Argv) auto& ifRequest = *reinterpret_cast(ifreqMemory->data()); memcpy(request.InterfaceName, ifRequest.ifr_ifrn.ifrn_name, sizeof(request.InterfaceName)); request.InterfaceUp = ifRequest.ifr_ifru.ifru_flags & IFF_UP; + + std::lock_guard lock(*channelMutex); const auto& reply = hvSocketChannel->Transaction(request); return reply.Result; diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp index f31ebf732..c5c8f4d9a 100644 --- a/src/linux/init/main.cpp +++ b/src/linux/init/main.cpp @@ -3540,7 +3540,7 @@ wil::unique_fd RegisterSeccompHook() Routine Description: - Register a 
seccomp notification for bind() & ioctl(*, TUNSETIFF, *) calls. + Register a seccomp notification for bind() & ioctl(*, SIOCSIFFLAGS, *) calls. Arguments: @@ -3568,7 +3568,7 @@ Return Value: // If syscall_nr == __NR_bind then goto user_notify: else continue BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_nr), BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_bind, 3, 0), - // if (syscall_nr == __NR_bind) then continue else goto allow: + // if (syscall_nr == __NR_ioctl) then continue else goto allow: BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, __NR_ioctl, 0, 3), // if (syscall arg1 == SIOCSIFFLAGS) goto user_notify else goto allow: BPF_STMT(BPF_LD + BPF_W + BPF_ABS, syscall_arg(1)),