diff --git a/src/windows/common/VirtioNetworking.cpp b/src/windows/common/VirtioNetworking.cpp index cc6d794eb..7f2d9439d 100644 --- a/src/windows/common/VirtioNetworking.cpp +++ b/src/windows/common/VirtioNetworking.cpp @@ -101,7 +101,6 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int return S_OK; } - int result = 0; const auto ipAddress = (addr.si_family == AF_INET) ? reinterpret_cast(&addr.Ipv4.sin_addr) : reinterpret_cast(&addr.Ipv6.sin6_addr); const bool loopback = INET_IS_ADDR_LOOPBACK(addr.si_family, ipAddress); @@ -111,10 +110,12 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int // Only intercepting 127.0.0.1; any other loopback address will remain on 'lo'. if (addr.Ipv4.sin_addr.s_addr != htonl(INADDR_LOOPBACK)) { - return result; + return S_OK; } } + const auto guestPort = INETADDR_PORT(reinterpret_cast(&addr)); + if (WI_IsFlagSet(m_flags, VirtioNetworkingFlags::LocalhostRelay) && (unspecified || loopback)) { SOCKADDR_INET localAddr = addr; @@ -130,57 +131,102 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int localAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port; } } - result = ModifyOpenPorts(c_loopbackDeviceName, localAddr, protocol, allocate); - LOG_HR_IF_MSG( - E_FAIL, result != S_OK, "Failure adding localhost relay port %d", INETADDR_PORT(reinterpret_cast(&localAddr))); + + try + { + const auto addrStr = wsl::windows::common::string::SockAddrInetToString(localAddr); + ModifyOpenPorts(c_loopbackDeviceName, addrStr.c_str(), guestPort, guestPort, protocol, allocate); + } + catch (...) + { + LOG_CAUGHT_EXCEPTION_MSG("Failure adding localhost relay port %d", guestPort); + } } if (!loopback) { - const int localResult = ModifyOpenPorts(c_eth0DeviceName, addr, protocol, allocate); - LOG_HR_IF_MSG(E_FAIL, localResult != S_OK, "Failure adding relay port %d", INETADDR_PORT(reinterpret_cast(&addr))); - if (result == 0) + try + { + const auto addrStr = wsl::windows::common::string::SockAddrInetToString(addr); + ModifyOpenPorts(c_eth0DeviceName, addrStr.c_str(), guestPort, guestPort, protocol, allocate); + } + catch (...) { - result = localResult; + LOG_CAUGHT_EXCEPTION_MSG("Failure adding relay port %d", guestPort); } } - return result; + return S_OK; } -int VirtioNetworking::ModifyOpenPorts(_In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const +uint16_t VirtioNetworking::ModifyOpenPorts( + _In_ PCWSTR tag, _In_opt_ PCSTR hostAddress, _In_ uint16_t HostPort, _In_ uint16_t GuestPort, _In_ int protocol, _In_ bool isOpen) const { - if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) - { - LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", protocol); - return 0; - } + THROW_HR_IF_MSG( + HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), + protocol != IPPROTO_TCP && protocol != IPPROTO_UDP, + "Unsupported bind protocol %d", + protocol); auto lock = m_lock.lock_exclusive(); const auto server = m_guestDeviceManager->GetRemoteFileSystem(VIRTIO_NET_CLASS_ID, c_defaultDeviceTag); - if (server) + THROW_HR_IF(E_NOT_SET, !server); + + // format: tag={tag}[;host_port={port}];guest_port={port}[;listen_addr={addr}|;allocate=false][;udp] + std::wstring portString = std::format(L"tag={};guest_port={};listen_addr={}", tag, GuestPort, hostAddress); + + if (HostPort != WSLC_EPHEMERAL_PORT) { - std::wstring portString = std::format(L"tag={};port_number={}", tag, INETADDR_PORT(reinterpret_cast(&addr))); - if (protocol == IPPROTO_UDP) - { - portString += L";udp"; - } + portString += std::format(L";host_port={}", HostPort); + } - if (!isOpen) - { - portString += L";allocate=false"; - } - else - { - const auto addrStr = wsl::windows::common::string::SockAddrInetToWstring(addr); - portString += std::format(L";listen_addr={}", addrStr); - } + if (!isOpen) + { + portString += L";allocate=false"; + } + + if (protocol == IPPROTO_UDP) + { + portString += L";udp"; + } + + const HRESULT addShareResult = server->AddShare(portString.c_str(), nullptr, 0); - LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0)); + if (HostPort == WSLC_EPHEMERAL_PORT && isOpen && SUCCEEDED(addShareResult)) + { + // For anonymous binds, the allocated host port is encoded in the return value. + return static_cast(addShareResult - S_OK); } - return 0; + THROW_IF_FAILED_MSG(addShareResult, "Failed to set virtionet port mapping: %ls", portString.c_str()); + return HostPort; +} + +HRESULT VirtioNetworking::MapPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ PCSTR ListenAddress, _Out_ USHORT* AllocatedHostPort) const +try +{ + RETURN_HR_IF(E_POINTER, AllocatedHostPort == nullptr || ListenAddress == nullptr); + RETURN_HR_IF_MSG(E_INVALIDARG, Protocol != IPPROTO_TCP && Protocol != IPPROTO_UDP, "Invalid protocol: %i", Protocol); + + *AllocatedHostPort = 0; + + *AllocatedHostPort = ModifyOpenPorts(c_eth0DeviceName, ListenAddress, HostPort, GuestPort, Protocol, true); + return S_OK; +} +CATCH_RETURN() + +HRESULT VirtioNetworking::UnmapPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ PCSTR ListenAddress) const +try +{ + RETURN_HR_IF(E_POINTER, ListenAddress == nullptr); + RETURN_HR_IF(E_INVALIDARG, Protocol != IPPROTO_TCP && Protocol != IPPROTO_UDP); + + const auto listenAddrW = wsl::shared::string::MultiByteToWide(ListenAddress); + + ModifyOpenPorts(c_eth0DeviceName, nullptr, HostPort, GuestPort, Protocol, false); + return S_OK; } +CATCH_RETURN() void VirtioNetworking::RefreshGuestConnection() { diff --git a/src/windows/common/VirtioNetworking.h b/src/windows/common/VirtioNetworking.h index 5a95e6016..2804de9ef 100644 --- a/src/windows/common/VirtioNetworking.h +++ b/src/windows/common/VirtioNetworking.h @@ -46,11 +46,15 @@ class VirtioNetworking : public INetworkingEngine void FillInitialConfiguration(LX_MINI_INIT_NETWORKING_CONFIGURATION& message) override; void StartPortTracker(wil::unique_socket&& socket) override; + HRESULT MapPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ PCSTR ListenAddress, _Out_ USHORT* AllocatedHostPort) const; + + HRESULT UnmapPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ PCSTR ListenAddress) const; + private: static void NETIOAPI_API_ OnNetworkConnectivityChange(PVOID context, NL_NETWORK_CONNECTIVITY_HINT hint); HRESULT HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept; - int ModifyOpenPorts(_In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const; + uint16_t ModifyOpenPorts(_In_ PCWSTR tag, _In_opt_ PCSTR hostAddress, _In_ uint16_t HostPort, _In_ uint16_t GuestPort, _In_ int protocol, _In_ bool isOpen) const; void RefreshGuestConnection(); void SetupLoopbackDevice(); void SendDefaultRoute(const std::wstring& gateway, wsl::shared::hns::ModifyRequestType requestType); diff --git a/src/windows/service/exe/HcsVirtualMachine.cpp b/src/windows/service/exe/HcsVirtualMachine.cpp index b554128f0..a26dc636a 100644 --- a/src/windows/service/exe/HcsVirtualMachine.cpp +++ b/src/windows/service/exe/HcsVirtualMachine.cpp @@ -581,6 +581,36 @@ try } CATCH_RETURN() +HRESULT HcsVirtualMachine::MapVirtioNetPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ LPCSTR ListenAddress, _Out_ USHORT* AllocatedHostPort) +try +{ + RETURN_HR_IF(E_POINTER, AllocatedHostPort == nullptr || ListenAddress == nullptr); + + *AllocatedHostPort = 0; + + std::lock_guard lock(m_lock); + + auto* virtioNet = dynamic_cast(m_networkEngine.get()); + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), virtioNet == nullptr); + + return virtioNet->MapPort(HostPort, GuestPort, Protocol, ListenAddress, AllocatedHostPort); +} +CATCH_RETURN() + +HRESULT HcsVirtualMachine::UnmapVirtioNetPort(_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ LPCSTR ListenAddress) +try +{ + RETURN_HR_IF(E_POINTER, ListenAddress == nullptr); + + std::lock_guard lock(m_lock); + + auto* virtioNet = dynamic_cast(m_networkEngine.get()); + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), virtioNet == nullptr); + + return virtioNet->UnmapPort(HostPort, GuestPort, Protocol, ListenAddress); +} +CATCH_RETURN() + void CALLBACK HcsVirtualMachine::OnVmExitCallback(HCS_EVENT* Event, void* Context) try { diff --git a/src/windows/service/exe/HcsVirtualMachine.h b/src/windows/service/exe/HcsVirtualMachine.h index 5f6cdf36d..fbc781644 100644 --- a/src/windows/service/exe/HcsVirtualMachine.h +++ b/src/windows/service/exe/HcsVirtualMachine.h @@ -43,6 +43,10 @@ class HcsVirtualMachine IFACEMETHOD(DetachDisk)(_In_ ULONG Lun) override; IFACEMETHOD(AddShare)(_In_ LPCWSTR WindowsPath, _In_ BOOL ReadOnly, _Out_ GUID* ShareId) override; IFACEMETHOD(RemoveShare)(_In_ REFGUID ShareId) override; + IFACEMETHOD(MapVirtioNetPort) + (_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ LPCSTR ListenAddress, _Out_ USHORT* AllocatedHostPort) override; + IFACEMETHOD(UnmapVirtioNetPort) + (_In_ USHORT HostPort, _In_ USHORT GuestPort, _In_ int Protocol, _In_ LPCSTR ListenAddress) override; private: struct DiskInfo diff --git a/src/windows/service/inc/wslc.idl b/src/windows/service/inc/wslc.idl index dce02dbae..ac1d64dc9 100644 --- a/src/windows/service/inc/wslc.idl +++ b/src/windows/service/inc/wslc.idl @@ -25,12 +25,14 @@ cpp_quote("#endif") #define WSLC_MAX_VOLUME_NAME_LENGTH 255 #define WSLC_CONTAINER_ID_LENGTH 64 #define WSLC_MAX_BINDING_ADDRESS_LENGTH 45 +#define WSLC_EPHEMERAL_PORT 0 cpp_quote("#define WSLC_MAX_CONTAINER_NAME_LENGTH 255") cpp_quote("#define WSLC_MAX_IMAGE_NAME_LENGTH 255") cpp_quote("#define WSLC_MAX_VOLUME_NAME_LENGTH 255") cpp_quote("#define WSLC_CONTAINER_ID_LENGTH 64") cpp_quote("#define WSLC_MAX_BINDING_ADDRESS_LENGTH 45") +cpp_quote("#define WSLC_EPHEMERAL_PORT 0") typedef struct _WSLCVersion { @@ -425,6 +427,25 @@ interface IWSLCVirtualMachine : IUnknown // Removes a previously added filesystem share. HRESULT RemoveShare([in] REFGUID ShareId); + + // Maps a port via VirtioNetworking. + // For anonymous binds (HostPort == WSLC_EPHEMERAL_PORT), the networking engine allocates a host port + // and returns it in AllocatedHostPort. + // Protocol must be IPPROTO_TCP or IPPROTO_UDP. + // ListenAddress is the IP address to bind on (e.g. "127.0.0.1", "0.0.0.0", "::1"). + HRESULT MapVirtioNetPort( + [in] USHORT HostPort, + [in] USHORT GuestPort, + [in] int Protocol, + [in] LPCSTR ListenAddress, + [out, retval] USHORT* AllocatedHostPort); + + // Unmaps a port previously mapped via MapVirtioNetPort. + HRESULT UnmapVirtioNetPort( + [in] USHORT HostPort, + [in] USHORT GuestPort, + [in] int Protocol, + [in] LPCSTR ListenAddress); } typedef enum _WSLCSessionStorageFlags diff --git a/src/windows/wslc/services/ContainerService.cpp b/src/windows/wslc/services/ContainerService.cpp index 05a5cc295..1ba743cf1 100644 --- a/src/windows/wslc/services/ContainerService.cpp +++ b/src/windows/wslc/services/ContainerService.cpp @@ -54,27 +54,21 @@ static wsl::windows::common::RunningWSLCContainer CreateInternal(Session& sessio { auto portMapping = PublishPort::Parse(port); + const int protocol = portMapping.PortProtocol() == PublishPort::Protocol::UDP ? IPPROTO_UDP : IPPROTO_TCP; + const int family = (portMapping.HostIP().has_value() && portMapping.HostIP()->IsIPv6()) ? AF_INET6 : AF_INET; + std::optional bindAddress; + if (portMapping.HostIP().has_value()) { - // https://github.com/microsoft/WSL/issues/14433 - // The following scenarios are currently not implemented: - // - Ephemeral host port mappings - // - Host port mappings with a specific host IP - // - Host port mappings with UDP protocol - if (portMapping.HostPort().IsEphemeral() || portMapping.HostIP().has_value() || - portMapping.PortProtocol() == PublishPort::Protocol::UDP) - { - THROW_HR_WITH_USER_ERROR( - HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), - "Port mappings with ephemeral host ports, specific host IPs, or UDP protocol are not currently supported"); - } + bindAddress = portMapping.HostIP()->IP(); } auto containerPort = portMapping.ContainerPort(); for (uint16_t i = 0; i < containerPort.Count(); ++i) { auto currentContainerPort = static_cast(containerPort.Start() + i); - auto currentHostPort = static_cast(portMapping.HostPort().Start() + i); - containerLauncher.AddPort(currentHostPort, currentContainerPort, AF_INET); + auto currentHostPort = portMapping.HostPort().IsEphemeral() ? static_cast(WSLC_EPHEMERAL_PORT) + : static_cast(portMapping.HostPort().Start() + i); + containerLauncher.AddPort(currentHostPort, currentContainerPort, family, protocol, bindAddress); } } diff --git a/src/windows/wslcsession/WSLCVirtualMachine.cpp b/src/windows/wslcsession/WSLCVirtualMachine.cpp index 734d0d049..b7957e00c 100644 --- a/src/windows/wslcsession/WSLCVirtualMachine.cpp +++ b/src/windows/wslcsession/WSLCVirtualMachine.cpp @@ -180,6 +180,19 @@ uint16_t VMPortMapping::HostPort() const } } +void VMPortMapping::SetHostPort(uint16_t port) +{ + if (BindAddress.si_family == AF_INET6) + { + BindAddress.Ipv6.sin6_port = htons(port); + } + else + { + WI_ASSERT(BindAddress.si_family == AF_INET); + BindAddress.Ipv4.sin_port = htons(port); + } +} + std::string VMPortMapping::BindingAddressString() const { char buffer[INET6_ADDRSTRLEN]{}; @@ -890,15 +903,15 @@ void WSLCVirtualMachine::MapPort(VMPortMapping& Mapping) } else if (m_networkingMode == WSLCNetworkingModeVirtioProxy) { - // TODO: Switch to using the native virtionet relay. - THROW_HR_IF_MSG( - HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), - !Mapping.IsLocalhost() || Mapping.Protocol != IPPROTO_TCP, - "Unsupported port mapping for virtionet mode: %hs, protocol: %i", - Mapping.BindingAddressString().c_str(), - Mapping.Protocol); + USHORT allocatedHostPort = 0; + THROW_IF_FAILED(m_vm->MapVirtioNetPort( + Mapping.HostPort(), Mapping.VmPort->Port(), Mapping.Protocol, Mapping.BindingAddressString().c_str(), &allocatedHostPort)); - MapRelayPort(Mapping.BindAddress.si_family, Mapping.HostPort(), Mapping.VmPort->Port(), false); + // For anonymous binds, write back the allocated host port. + if (Mapping.HostPort() == WSLC_EPHEMERAL_PORT && allocatedHostPort != 0) + { + Mapping.SetHostPort(allocatedHostPort); + } } else { @@ -922,8 +935,8 @@ void WSLCVirtualMachine::UnmapPort(VMPortMapping& Mapping) } else if (m_networkingMode == WSLCNetworkingModeVirtioProxy) { - // TODO: Switch to using the native virtionet relay. - MapRelayPort(Mapping.BindAddress.si_family, Mapping.HostPort(), Mapping.VmPort->Port(), true); + THROW_IF_FAILED(m_vm->UnmapVirtioNetPort( + Mapping.HostPort(), Mapping.VmPort->Port(), Mapping.Protocol, Mapping.BindingAddressString().c_str())); } else { diff --git a/src/windows/wslcsession/WSLCVirtualMachine.h b/src/windows/wslcsession/WSLCVirtualMachine.h index 99cd515ec..d3fa66e4f 100644 --- a/src/windows/wslcsession/WSLCVirtualMachine.h +++ b/src/windows/wslcsession/WSLCVirtualMachine.h @@ -90,6 +90,7 @@ struct VMPortMapping void Attach(WSLCVirtualMachine& Vm); void Detach(); uint16_t HostPort() const; + void SetHostPort(uint16_t port); static VMPortMapping LocalhostTcpMapping(int Family, uint16_t WindowsPort); static VMPortMapping FromWSLCPortMapping(const ::WSLCPortMapping& Mapping);