From 8a66990cf5f2f1fb09105839a0f9e7412e22105d Mon Sep 17 00:00:00 2001 From: Steve Syfuhs Date: Fri, 12 Dec 2025 11:28:31 -0800 Subject: [PATCH] Put ping check behind configuration --- .../Client/Transport/ClientDomainService.cs | 80 +++++++++++- .../Transport/HttpsKerberosTransport.cs | 2 - .../Client/Transport/KerberosTransportBase.cs | 76 ++++------- .../Configuration/Krb5ConfigDefaults.cs | 7 + Kerberos.NET/Dns/DnsRecord.cs | 2 + Kerberos.NET/TaskExtensions.cs | 29 ++++- .../End2End/TcpClientServerTests.cs | 120 ++++++++++++++++++ 7 files changed, 258 insertions(+), 58 deletions(-) diff --git a/Kerberos.NET/Client/Transport/ClientDomainService.cs b/Kerberos.NET/Client/Transport/ClientDomainService.cs index 396413a7..8ba1b3a1 100644 --- a/Kerberos.NET/Client/Transport/ClientDomainService.cs +++ b/Kerberos.NET/Client/Transport/ClientDomainService.cs @@ -3,6 +3,8 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; +using System.Net.NetworkInformation; +using System.Threading; using System.Threading.Tasks; using Kerberos.NET.Configuration; using Kerberos.NET.Dns; @@ -13,6 +15,8 @@ namespace Kerberos.NET.Transport { public class ClientDomainService { + private static readonly Random Random = new(); + public ClientDomainService(ILoggerFactory logger) { this.logger = logger.CreateLoggerSafe(); @@ -47,6 +51,13 @@ static ClientDomainService() public Krb5Config Configuration { get; set; } + public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2); + + public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10); + + public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10); + + public void ResetConnections() { DomainCache.Clear(); @@ -59,7 +70,37 @@ public virtual async Task> LocateKdc(string domain, strin { var results = await this.Query(domain, servicePrefix, DefaultKerberosPort); - return ParseQuerySrvReply(results); + results = ParseQuerySrvReply(results); + + return await WeightResults(results); + } + + private async Task> WeightResults(IEnumerable results) + { + SortedList fastest = new(); + + if (this.Configuration.Defaults.PrioritizeKdcByPing) + { + try + { + using var cts = new CancellationTokenSource(this.ConnectTimeout); + + fastest = await results.GetFastestAsync(PingAsync, cts.Token); + } + catch (Exception ex) + { + this.logger.LogWarning(ex, "Ping failed for all found services"); + } + } + + foreach (var r in results) + { + var speed = fastest.FirstOrDefault(f => string.Equals(f.Value.Target, r.Target, StringComparison.OrdinalIgnoreCase)); + + r.PingResponseTime = speed.Value != null ? speed.Key : Random.Next(fastest.Count, int.MaxValue); + } + + return results; } public virtual async Task> LocateKpasswd(string domain, string servicePrefix) @@ -153,6 +194,43 @@ protected virtual async Task> Query(string domain, string return records; } + protected virtual async Task PingAsync(DnsRecord record, CancellationToken cancellationToken) + { + using var ping = new Ping(); + + cancellationToken.Register(() => ping.SendAsyncCancel()); + + var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(this.ConnectTimeout.TotalMilliseconds)); + + return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}"); + } + + private class DnsRecordComparer : IEqualityComparer + { + public static readonly DnsRecordComparer Instance = new(); + + private DnsRecordComparer() + { + } + + public bool Equals(DnsRecord x, DnsRecord y) + { + if (ReferenceEquals(x, y)) return true; + if (x is null) return false; + if (y is null) return false; + if (x.GetType() != y.GetType()) return false; + return x.Target == y.Target && x.Port == y.Port; + } + + public int GetHashCode(DnsRecord obj) + { + unchecked + { + return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port; + } + } + } + private async Task QueryDns(string domain, string servicePrefix, List records) { var lookup = Invariant($"{servicePrefix}.{domain}"); diff --git a/Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs b/Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs index 4b2c8e0d..d953ab47 100644 --- a/Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs +++ b/Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs @@ -19,8 +19,6 @@ namespace Kerberos.NET.Transport { public class HttpsKerberosTransport : KerberosTransportBase { - private static readonly Random Random = new Random(); - private readonly ILogger logger; public HttpsKerberosTransport(ILoggerFactory logger = null) diff --git a/Kerberos.NET/Client/Transport/KerberosTransportBase.cs b/Kerberos.NET/Client/Transport/KerberosTransportBase.cs index 919188cb..48512759 100644 --- a/Kerberos.NET/Client/Transport/KerberosTransportBase.cs +++ b/Kerberos.NET/Client/Transport/KerberosTransportBase.cs @@ -6,7 +6,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Net.NetworkInformation; using System.Threading; using System.Threading.Tasks; using Kerberos.NET.Asn1; @@ -20,14 +19,17 @@ namespace Kerberos.NET.Transport { public abstract class KerberosTransportBase : IKerberosTransport2, IDisposable { + protected static readonly Random Random = new(); + + private bool disposedValue; + protected KerberosTransportBase(ILoggerFactory logger) { this.ClientRealmService = new ClientDomainService(logger); + this.Logger = logger.CreateLoggerSafe(); } - private bool disposedValue; - - private DnsRecord fastest; + protected ILogger Logger { get; } public virtual bool TransportFailed { get; set; } @@ -35,11 +37,23 @@ protected KerberosTransportBase(ILoggerFactory logger) public bool Enabled { get; set; } - public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2); + public TimeSpan ConnectTimeout + { + get => this.ClientRealmService.ConnectTimeout; + set => this.ClientRealmService.ConnectTimeout = value; + } - public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10); + public TimeSpan SendTimeout + { + get => this.ClientRealmService.SendTimeout; + set => this.ClientRealmService.SendTimeout = value; + } - public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10); + public TimeSpan ReceiveTimeout + { + get => this.ClientRealmService.ReceiveTimeout; + set => this.ClientRealmService.ReceiveTimeout = value; + } public int MaximumAttempts { get; set; } = 30; @@ -166,58 +180,20 @@ public void Dispose() protected virtual async Task LocatePreferredKdc(string domain, string servicePrefix) { var results = await this.LocateKdc(domain, servicePrefix); - return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort); + return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort); } protected virtual async Task LocatePreferredKpasswd(string domain, string servicePrefix) { var results = await this.LocateKpasswd(domain, servicePrefix); - return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort); - } - - protected virtual async Task SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable results, int defaultPort) - { - if (results.Contains(fastest, DnsRecordComparer.Instance)) - { - return fastest; - } - - fastest = await results.Where(r => r.Name.StartsWith(servicePrefix)).GetFastestAsync(PingAsync); - return fastest ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}"); - } - - private async Task PingAsync(DnsRecord record, CancellationToken cancellationToken) - { - using var ping = new Ping(); - cancellationToken.Register(() => ping.SendAsyncCancel()); - var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(ConnectTimeout.TotalMilliseconds)); - return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}"); + return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort); } - private class DnsRecordComparer : IEqualityComparer + protected virtual DnsRecord SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable results, int defaultPort) { - public static readonly DnsRecordComparer Instance = new(); - - private DnsRecordComparer() - { - } + results = results.Where(r => r.Name.StartsWith(servicePrefix)).OrderBy(r => r.PingResponseTime); - public bool Equals(DnsRecord x, DnsRecord y) - { - if (ReferenceEquals(x, y)) return true; - if (x is null) return false; - if (y is null) return false; - if (x.GetType() != y.GetType()) return false; - return x.Target == y.Target && x.Port == y.Port; - } - - public int GetHashCode(DnsRecord obj) - { - unchecked - { - return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port; - } - } + return results.FirstOrDefault() ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}"); } } } diff --git a/Kerberos.NET/Configuration/Krb5ConfigDefaults.cs b/Kerberos.NET/Configuration/Krb5ConfigDefaults.cs index 886d4da1..056b6b7f 100644 --- a/Kerberos.NET/Configuration/Krb5ConfigDefaults.cs +++ b/Kerberos.NET/Configuration/Krb5ConfigDefaults.cs @@ -346,5 +346,12 @@ public class Krb5ConfigDefaults : Krb5ConfigObject [DefaultValue(PrincipalNameType.NT_ENTERPRISE)] [DisplayName("default_name_type")] public PrincipalNameType DefaultNameType { get; set; } + + /// + /// Indicates whether the client should try to find and sort KDCs by how long it takes for them to respond by ping. + /// + [DefaultValue(true)] + [DisplayName("prioritize_by_response_time")] + public bool PrioritizeKdcByPing { get; set; } } } diff --git a/Kerberos.NET/Dns/DnsRecord.cs b/Kerberos.NET/Dns/DnsRecord.cs index 4f79c51a..0a43b66c 100644 --- a/Kerberos.NET/Dns/DnsRecord.cs +++ b/Kerberos.NET/Dns/DnsRecord.cs @@ -49,5 +49,7 @@ public string Address return this.Target; } } + + public int PingResponseTime { get; set; } = int.MaxValue; } } diff --git a/Kerberos.NET/TaskExtensions.cs b/Kerberos.NET/TaskExtensions.cs index 35cc3216..e6a340f5 100644 --- a/Kerberos.NET/TaskExtensions.cs +++ b/Kerberos.NET/TaskExtensions.cs @@ -1,4 +1,4 @@ -// ----------------------------------------------------------------------- +// ----------------------------------------------------------------------- // Licensed to The .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // ----------------------------------------------------------------------- @@ -11,31 +11,50 @@ internal static class TaskExtensions { - public static async Task GetFastestAsync(this IEnumerable source, Func> task, CancellationToken cancellationToken = default) + public static async Task> GetFastestAsync( + this IEnumerable source, + Func> task, + CancellationToken cancellationToken = default + ) { using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); var tasks = new HashSet>(source.Select(e => task(e, cts.Token))); + if (tasks.Count == 0) { - return default; + return new(); } + int next = 0; + SortedList results = new(); + var exceptions = new List(); + do { var completedTask = await Task.WhenAny(tasks); + if (completedTask.Status == TaskStatus.RanToCompletion) { cts.Cancel(); - return completedTask.Result; + + results.Add(++next, completedTask.Result); } if (completedTask.Exception != null) { exceptions.AddRange(completedTask.Exception.InnerExceptions); } + tasks.Remove(completedTask); - } while (tasks.Count > 0); + + } + while (tasks.Count > 0); + + if (results.Count > 0) + { + return results; + } throw new AggregateException(exceptions); } diff --git a/Tests/Tests.Kerberos.NET/End2End/TcpClientServerTests.cs b/Tests/Tests.Kerberos.NET/End2End/TcpClientServerTests.cs index 20495727..6f8c6184 100644 --- a/Tests/Tests.Kerberos.NET/End2End/TcpClientServerTests.cs +++ b/Tests/Tests.Kerberos.NET/End2End/TcpClientServerTests.cs @@ -5,10 +5,17 @@ using System; using System.Linq; +using System.Net.NetworkInformation; +using System.Threading; using System.Threading.Tasks; using Kerberos.NET.Client; +using Kerberos.NET.Configuration; using Kerberos.NET.Credentials; +using Kerberos.NET.Dns; using Kerberos.NET.Transport; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.VisualBasic.Logging; using Microsoft.VisualStudio.TestTools.UnitTesting; using static Tests.Kerberos.NET.KdcListener; @@ -88,6 +95,119 @@ await RequestAndValidateTickets( } } + [TestMethod] + public async Task ClientConnectsWithoutPing() + { + var port = NextPort(); + + using (var listener = StartTcpListener(port)) + { + _ = listener.Start(); + + using (var client = new KerberosClient() + { + ConnectTimeout = TimeSpan.FromMilliseconds(1) + }) + { + client.Configuration.Defaults.PrioritizeKdcByPing = false; + + client.PinKdc("corp.identityintervention.com", $"127.0.0.1:{port}"); + + try + { + await client.Authenticate(new KerberosPasswordCredential(AdminAtCorpUserName, FakeAdminAtCorpPassword)); + } + catch (AggregateException agg) + { + throw agg.InnerExceptions.First(); + } + } + } + } + + [TestMethod] + public async Task ClientConnectsWithExplodingPing() + { + var port = NextPort(); + + using (var listener = StartTcpListener(port)) + { + _ = listener.Start(); + + using (var client = new KerberosClient(transports: new PingTransport(NullLoggerFactory.Instance, port) { BlockPing = true }) + { + ConnectTimeout = TimeSpan.FromMilliseconds(1) + }) + { + Assert.IsTrue(client.Configuration.Defaults.PrioritizeKdcByPing); + + client.PinKdc("corp.identityintervention.com", $"127.0.0.1:{port}"); + + try + { + await client.Authenticate(new KerberosPasswordCredential(AdminAtCorpUserName, FakeAdminAtCorpPassword)); + } + catch (AggregateException agg) + { + throw agg.InnerExceptions.First(); + } + } + } + } + + private class PingTransport : TcpKerberosTransport + { + private readonly ILoggerFactory log; + private readonly int port; + + public PingTransport(ILoggerFactory logger, int port) : base(logger) + { + this.log = logger; + this.port = port; + } + + private ExplodyClientRealmService crs; + + public override ClientDomainService ClientRealmService => crs ??= new ExplodyClientRealmService(log) + { + BlockPing = this.BlockPing, + Configuration = Krb5Config.Default() + }; + + public bool BlockPing { get; set; } + + public bool DontPing + { + get => this.Configuration.Defaults.PrioritizeKdcByPing; + set => this.Configuration.Defaults.PrioritizeKdcByPing = value; + } + + //protected override Task LocatePreferredKdc(string domain, string servicePrefix) + //{ + // return Task.FromResult(new DnsRecord { Target = "127.0.0.1", Port = this.port }); + //} + + private class ExplodyClientRealmService : ClientDomainService + { + public ExplodyClientRealmService(ILoggerFactory logger) : base(logger) + { + } + + public bool BlockPing { get; set; } + + protected override Task PingAsync(DnsRecord record, CancellationToken cancellationToken) + { + if (BlockPing) + { + throw new PingException("Goes bang"); + } + + return base.PingAsync(record, cancellationToken); + } + } + } + + [TestMethod] public async Task TCP_MultithreadedClient() {