diff --git a/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs b/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs index e23258fa78f..97283a044bf 100644 --- a/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs +++ b/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs @@ -202,16 +202,16 @@ void DisposeStream () stream.Dispose (); } - static bool ShouldMapToCancellation (Exception ex, CancellationToken cancellationToken) - { - return cancellationToken.IsCancellationRequested && - ex is global::System.IO.IOException - or Java.IO.IOException - or InvalidDataException - or ObjectDisposedException - or WebException; - } + } + static bool ShouldMapToCancellation (Exception ex, CancellationToken cancellationToken) + { + return cancellationToken.IsCancellationRequested && + ex is global::System.IO.IOException + or Java.IO.IOException + or InvalidDataException + or ObjectDisposedException + or WebException; } internal const string LOG_APP = "monodroid-net"; @@ -788,6 +788,11 @@ protected virtual async Task WriteRequestContentToOutput (HttpRequestMessage req var stream = await request.Content.ReadAsStreamAsync ().ConfigureAwait (false); try { await stream.CopyToAsync(httpConnection.OutputStream!, 4096, cancellationToken).ConfigureAwait(false); + } catch (Exception ex) when (ShouldMapToCancellation (ex, cancellationToken)) { + // When the caller cancels the request while the body is being uploaded, the connection + // is disconnected which surfaces as a transport exception (e.g. "Socket closed"). Map it + // to an OperationCanceledException so callers observe cancellation instead of a WebException. + throw new System.OperationCanceledException ("Request body upload was canceled.", ex, cancellationToken); } finally { // // Rewind the stream to beginning in case the HttpContent implementation diff --git a/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs b/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs index 0d6e5439a9a..5e01ba864d7 100644 --- a/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs +++ b/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs @@ -1,6 +1,7 @@ #nullable enable using System; +using System.IO; using System.Net; using System.Net.Http; using System.Net.Sockets; @@ -19,6 +20,7 @@ namespace Xamarin.Android.NetTests public class AndroidMessageHandlerCancellationTests { const int StalledResponseContentLength = 1024 * 1024; + const int UploadContentLength = 16 * 1024 * 1024; const int BodyReadBlockDelayMilliseconds = 250; const int PromptCancellationTimeoutMilliseconds = 3000; @@ -58,6 +60,27 @@ public async Task ResponseContentReadBodyReadCancellationIsPrompt () await AssertCanceledPromptly (readTask, server.ReleaseResponseBody).ConfigureAwait (false); } + [Test] + public async Task RequestBodyUploadCancellationIsPrompt () + { + using var uploadServer = new StalledRequestServer (); + using var handler = new AndroidMessageHandler (); + using var client = new HttpClient (handler); + using var cts = new CancellationTokenSource (); + using var request = new HttpRequestMessage (HttpMethod.Put, $"http://localhost:{uploadServer.Port}/upload") { + // A large body ensures the socket send buffer fills while the server stalls reading it, + // so the upload is still in progress when the caller cancels. The content streams the + // bytes in small chunks instead of allocating the whole body up front. + Content = new StreamingContent (UploadContentLength), + }; + + Task sendTask = client.SendAsync (request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + await WaitForBodyReadToBlock (uploadServer.BodyStartedTask).ConfigureAwait (false); + cts.Cancel (); + await AssertCanceledPromptly (sendTask, uploadServer.ReleaseRequestBody).ConfigureAwait (false); + } + [Test] public async Task ResponseHeadersReadBodyReadCancellationIsPrompt () { @@ -199,5 +222,156 @@ async Task ObserveServerTask () await serverTask.ConfigureAwait (false); } } + + sealed class StreamingContent : HttpContent + { + readonly long length; + + public StreamingContent (long length) + { + this.length = length; + } + + protected override Task CreateContentReadStreamAsync () + { + // AndroidMessageHandler uses ReadAsStreamAsync () before uploading; override this to avoid + // HttpContent's default full-buffering behavior for custom content. + return Task.FromResult (new ZeroStream (length)); + } + + protected override async Task SerializeToStreamAsync (Stream stream, System.Net.TransportContext? context) + { + using var source = new ZeroStream (length); + await source.CopyToAsync (stream, 4096).ConfigureAwait (false); + } + + protected override bool TryComputeLength (out long computedLength) + { + computedLength = length; + return true; + } + + sealed class ZeroStream : Stream + { + readonly long length; + long position; + + public ZeroStream (long length) + { + this.length = length; + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => length; + + public override long Position { + get => position; + set => throw new NotSupportedException (); + } + + public override int Read (byte [] buffer, int offset, int count) + { + if (position >= length) + return 0; + + int toRead = (int) Math.Min (count, length - position); + Array.Clear (buffer, offset, toRead); + position += toRead; + return toRead; + } + + public override ValueTask ReadAsync (Memory buffer, CancellationToken cancellationToken = default) + { + if (position >= length) + return new ValueTask (0); + + int toRead = (int) Math.Min (buffer.Length, length - position); + buffer.Span.Slice (0, toRead).Clear (); + position += toRead; + return new ValueTask (toRead); + } + + public override void Flush () { } + public override long Seek (long offset, SeekOrigin origin) => throw new NotSupportedException (); + public override void SetLength (long value) => throw new NotSupportedException (); + public override void Write (byte [] buffer, int offset, int count) => throw new NotSupportedException (); + } + } + + sealed class StalledRequestServer : IDisposable + { + readonly TcpListener listener; + readonly TaskCompletionSource bodyStarted = new TaskCompletionSource (TaskCreationOptions.RunContinuationsAsynchronously); + readonly TaskCompletionSource releaseBody = new TaskCompletionSource (TaskCreationOptions.RunContinuationsAsynchronously); + readonly Task serverTask; + + public StalledRequestServer () + { + listener = new TcpListener (IPAddress.Loopback, 0); + listener.Start (); + Port = ((IPEndPoint) listener.LocalEndpoint).Port; + + serverTask = StallRequestBody (); + } + + public int Port { get; } + + public Task BodyStartedTask => bodyStarted.Task; + + public void ReleaseRequestBody () + { + releaseBody.TrySetResult (true); + } + + public void Dispose () + { + ReleaseRequestBody (); + try { + listener.Stop (); + } catch (Exception ex) { + Console.WriteLine ($"Exception while stopping the stalled request server: {ex}"); + } + } + + async Task StallRequestBody () + { + try { + using var client = await listener.AcceptTcpClientAsync ().ConfigureAwait (false); + using var stream = client.GetStream (); + + // Read just the request headers so the upload phase begins, then stop reading the body + // to keep the socket send buffer full on the client side until released. + await ReadRequestHeaders (stream).ConfigureAwait (false); + bodyStarted.TrySetResult (true); + + await releaseBody.Task.ConfigureAwait (false); + } catch (Exception ex) { + if (!BodyStartedTask.IsCompleted) { + bodyStarted.TrySetException (ex); + return; + } + Console.WriteLine ($"Exception while stalling the request body: {ex}"); + } + } + + static async Task ReadRequestHeaders (NetworkStream stream) + { + var buffer = new byte [1]; + int consecutiveLineEndChars = 0; + while (consecutiveLineEndChars < 4) { + int read = await stream.ReadAsync (buffer, 0, 1).ConfigureAwait (false); + if (read == 0) + break; + + byte b = buffer [0]; + if (b == '\r' || b == '\n') + consecutiveLineEndChars++; + else + consecutiveLineEndChars = 0; + } + } + } } }