From 0ca5017097c8b34e4f464985c51eb515b2bac6da Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Mon, 16 Mar 2026 12:20:31 +0100 Subject: [PATCH 1/8] Implement the portal session url --- .../VNext/AccountBillingVNextController.cs | 21 + .../Requests/Portal/PortalSessionRequest.cs | 39 ++ .../Responses/Portal/PortalSessionResponse.cs | 13 + .../Extensions/ServiceCollectionExtensions.cs | 2 + .../CreateBillingPortalSessionCommand.cs | 93 ++++ src/Core/Billing/Services/IStripeAdapter.cs | 2 + .../Services/Implementations/StripeAdapter.cs | 9 + .../AccountBillingVNextControllerTests.cs | 145 ++++++ .../Portal/PortalSessionRequestTests.cs | 119 +++++ .../CreateBillingPortalSessionCommandTests.cs | 415 ++++++++++++++++++ 10 files changed, 858 insertions(+) create mode 100644 src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs create mode 100644 src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs create mode 100644 src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs create mode 100644 test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs create mode 100644 test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index 9facdd9b24de..17134cb171bf 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -1,11 +1,14 @@ using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.Portal; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Api.Billing.Models.Requests.Storage; +using Bit.Api.Billing.Models.Responses.Portal; using Bit.Core; using Bit.Core.Billing.Licenses.Queries; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Portal.Commands; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Subscriptions.Queries; @@ -21,6 +24,7 @@ namespace Bit.Api.Billing.Controllers.VNext; [Route("account/billing/vnext")] [SelfHosted(NotSelfHostedOnly = true)] public class AccountBillingVNextController( + ICreateBillingPortalSessionCommand createBillingPortalSessionCommand, ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, ICreatePremiumCloudHostedSubscriptionCommand createPremiumCloudHostedSubscriptionCommand, IGetBitwardenSubscriptionQuery getBitwardenSubscriptionQuery, @@ -147,4 +151,21 @@ public async Task GetApplicableDiscountsAsync( return Handle(result); } + /// + /// Creates a Stripe billing portal session for the authenticated user. + /// The portal allows users to manage their subscription, payment methods, and billing history. + /// + /// The authenticated user + /// Portal session configuration including return URL + /// Portal session URL for redirection + [HttpPost("portal-session")] + [InjectUser] + public async Task CreatePortalSessionAsync( + [BindNever] User user, + [FromBody] PortalSessionRequest request) + { + var result = await createBillingPortalSessionCommand.Run(user, request.ReturnUrl!); + return Handle(result.Map(url => new PortalSessionResponse { Url = url })); + } + } diff --git a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs new file mode 100644 index 000000000000..e65ce5c03bda --- /dev/null +++ b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs @@ -0,0 +1,39 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests.Portal; + +/// +/// Request model for creating a Stripe billing portal session. +/// +public class PortalSessionRequest : IValidatableObject +{ + /// + /// The URL to redirect to after the user completes their session in the billing portal. + /// Must be a valid HTTP(S) URL. + /// + [Required] + [MaxLength(2000)] + public string? ReturnUrl { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!string.IsNullOrWhiteSpace(ReturnUrl)) + { + if (!Uri.TryCreate(ReturnUrl, UriKind.Absolute, out var uri)) + { + yield return new ValidationResult( + "Return URL must be a valid absolute URL.", + [nameof(ReturnUrl)]); + yield break; + } + + // Prevent open redirect vulnerabilities by restricting to HTTP(S) schemes + if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) + { + yield return new ValidationResult( + "Return URL must use HTTP or HTTPS scheme.", + [nameof(ReturnUrl)]); + } + } + } +} diff --git a/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs new file mode 100644 index 000000000000..8ec1a47dc375 --- /dev/null +++ b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs @@ -0,0 +1,13 @@ +namespace Bit.Api.Billing.Models.Responses.Portal; + +/// +/// Response model containing the Stripe billing portal session URL. +/// +public class PortalSessionResponse +{ + /// + /// The URL to redirect the user to for accessing the Stripe billing portal. + /// This URL is time-limited and single-use. + /// + public required string Url { get; init; } +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index d7146455900c..4c929dc4b462 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Payment; +using Bit.Core.Billing.Portal.Commands; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Premium.Queries; using Bit.Core.Billing.Pricing; @@ -44,6 +45,7 @@ public static void AddBillingOperations(this IServiceCollection services) services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); } private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services) diff --git a/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs new file mode 100644 index 000000000000..90830708527d --- /dev/null +++ b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs @@ -0,0 +1,93 @@ +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Microsoft.Extensions.Logging; +using Stripe; +using Stripe.BillingPortal; + +namespace Bit.Core.Billing.Portal.Commands; + +using static StripeConstants; + +public interface ICreateBillingPortalSessionCommand +{ + Task> Run(User user, string returnUrl); +} + +public class CreateBillingPortalSessionCommand( + ILogger logger, + IStripeAdapter stripeAdapter) + : BaseBillingCommand(logger), ICreateBillingPortalSessionCommand +{ + private readonly ILogger _logger = logger; + + protected override Conflict DefaultConflict => + new("Unable to create billing portal session. Please contact support for assistance."); + + public Task> Run(User user, string returnUrl) => + HandleAsync(async () => + { + if (string.IsNullOrEmpty(user.GatewayCustomerId)) + { + _logger.LogWarning("{Command}: User ({UserId}) does not have a Stripe customer ID", + CommandName, user.Id); + return new BadRequest("User does not have a Stripe customer ID."); + } + + if (string.IsNullOrEmpty(user.GatewaySubscriptionId)) + { + _logger.LogWarning("{Command}: User ({UserId}) does not have a subscription", + CommandName, user.Id); + return new BadRequest("User does not have a Premium subscription."); + } + + // Fetch the subscription to validate its status + Subscription subscription; + try + { + subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId); + } + catch (StripeException stripeException) + { + _logger.LogError(stripeException, + "{Command}: Failed to fetch subscription ({SubscriptionId}) for user ({UserId})", + CommandName, user.GatewaySubscriptionId, user.Id); + return new BadRequest("Unable to verify subscription status."); + } + + if (subscription == null) + { + _logger.LogWarning("{Command}: Subscription ({SubscriptionId}) for user ({UserId}) was not found", + CommandName, user.GatewaySubscriptionId, user.Id); + return new BadRequest("User subscription not found."); + } + + // Only allow portal access for active or past_due subscriptions + if (subscription.Status != SubscriptionStatus.Active && subscription.Status != SubscriptionStatus.PastDue) + { + _logger.LogWarning( + "{Command}: User ({UserId}) subscription ({SubscriptionId}) has status '{Status}' which is not eligible for portal access", + CommandName, user.Id, user.GatewaySubscriptionId, subscription.Status); + return new BadRequest("Your subscription cannot be managed in its current status."); + } + + var options = new SessionCreateOptions + { + Customer = user.GatewayCustomerId, + ReturnUrl = returnUrl + }; + + var session = await stripeAdapter.CreateBillingPortalSessionAsync(options); + + if (session?.Url == null) + { + return DefaultConflict; + } + + _logger.LogInformation("{Command}: Successfully created billing portal session for user ({UserId})", + CommandName, user.Id); + + return session.Url; + }); +} diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs index d7d14432caf9..4c18837f2980 100644 --- a/src/Core/Billing/Services/IStripeAdapter.cs +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -3,6 +3,7 @@ using Bit.Core.Models.BitStripe; using Stripe; +using Stripe.BillingPortal; using Stripe.Tax; namespace Bit.Core.Billing.Services; @@ -52,4 +53,5 @@ Task CreateCustomerBalanceTransactionAsync(string cu Task GetCouponAsync(string couponId, CouponGetOptions options = null); Task> ListProductsAsync(ProductListOptions options = null); Task> ListSubscriptionsAsync(SubscriptionListOptions options = null); + Task CreateBillingPortalSessionAsync(SessionCreateOptions options); } diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs index 5672c6ca4d0f..866726b0bad7 100644 --- a/src/Core/Billing/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -4,6 +4,7 @@ using Bit.Core.Models.BitStripe; using Stripe; +using Stripe.BillingPortal; using Stripe.Tax; using Stripe.TestHelpers; using CustomerService = Stripe.CustomerService; @@ -29,6 +30,7 @@ public class StripeAdapter : IStripeAdapter private readonly RegistrationService _taxRegistrationService; private readonly CouponService _couponService; private readonly ProductService _productService; + private readonly SessionService _billingPortalSessionService; public StripeAdapter() { @@ -48,6 +50,7 @@ public StripeAdapter() _taxRegistrationService = new RegistrationService(); _couponService = new CouponService(); _productService = new ProductService(); + _billingPortalSessionService = new SessionService(); } /************** @@ -234,4 +237,10 @@ public async Task> ListProductsAsync(ProductListOptions options = ****************/ public Task> ListSubscriptionsAsync(SubscriptionListOptions options = null) => _subscriptionService.ListAsync(options); + + /********************** + ** BILLING PORTAL ** + **********************/ + public Task CreateBillingPortalSessionAsync(SessionCreateOptions options) => + _billingPortalSessionService.CreateAsync(options); } diff --git a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs index 706b3ae21967..67c01ca3a3da 100644 --- a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs @@ -1,10 +1,12 @@ using Bit.Api.Billing.Controllers.VNext; +using Bit.Api.Billing.Models.Requests.Portal; using Bit.Api.Billing.Models.Requests.Storage; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Licenses.Queries; using Bit.Core.Billing.Models.Api.Response; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Portal.Commands; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Subscriptions.Queries; @@ -14,8 +16,10 @@ using Microsoft.AspNetCore.Http.HttpResults; using NSubstitute; using OneOf.Types; +using Stripe; using Xunit; using BadRequest = Bit.Core.Billing.Commands.BadRequest; +using Conflict = Bit.Core.Billing.Commands.Conflict; namespace Bit.Api.Test.Billing.Controllers.VNext; @@ -25,6 +29,7 @@ public class AccountBillingVNextControllerTests private readonly IGetUserLicenseQuery _getUserLicenseQuery; private readonly IUpgradePremiumToOrganizationCommand _upgradePremiumToOrganizationCommand; private readonly IGetApplicableDiscountsQuery _getApplicableDiscountsQuery; + private readonly ICreateBillingPortalSessionCommand _createBillingPortalSessionCommand; private readonly AccountBillingVNextController _sut; public AccountBillingVNextControllerTests() @@ -33,8 +38,10 @@ public AccountBillingVNextControllerTests() _getUserLicenseQuery = Substitute.For(); _upgradePremiumToOrganizationCommand = Substitute.For(); _getApplicableDiscountsQuery = Substitute.For(); + _createBillingPortalSessionCommand = Substitute.For(); _sut = new AccountBillingVNextController( + _createBillingPortalSessionCommand, Substitute.For(), Substitute.For(), Substitute.For(), @@ -291,4 +298,142 @@ public async Task GetApplicableDiscountsAsync_EligibleDiscounts_ReturnsOkWithDis Assert.Equal(models, okResult.Value); await _getApplicableDiscountsQuery.Received(1).Run(user); } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_Success_ReturnsPortalUrlAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var portalUrl = "https://billing.stripe.com/session/test123"; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(portalUrl)); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsBadRequestAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new BadRequest("User does not have a Stripe customer ID."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsBadRequestAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new BadRequest("User does not have a Premium subscription."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_InvalidSubscriptionStatus_ReturnsBadRequestAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new BadRequest("Your subscription cannot be managed in its current status."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsBadRequestAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new BadRequest("User subscription not found."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_StripeException_ReturnsServerErrorAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var exception = new StripeException("Stripe API error"); + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new Unhandled(exception))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_SessionWithNullUrl_ReturnsServerErrorAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_NullSession_ReturnsServerErrorAsync(User user, string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + + _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); + + // Act + var result = await _sut.CreatePortalSessionAsync(user, request); + + // Assert + Assert.IsAssignableFrom(result); + await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + } } diff --git a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs new file mode 100644 index 000000000000..29f772268af5 --- /dev/null +++ b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs @@ -0,0 +1,119 @@ +using Bit.Api.Billing.Models.Requests.Portal; +using System.ComponentModel.DataAnnotations; +using Xunit; + +namespace Bit.Api.Test.Billing.Models.Requests.Portal; + +public class PortalSessionRequestTests +{ + [Theory] + [InlineData("https://example.com/return")] + [InlineData("http://localhost:3000/billing")] + [InlineData("https://app.bitwarden.com/settings/billing")] + public void Validate_ValidHttpsUrl_ReturnsNoErrors(string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = returnUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.True(isValid); + Assert.Empty(results); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public void Validate_MissingReturnUrl_ReturnsRequiredError(string? returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = returnUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.False(isValid); + Assert.Contains(results, r => r.MemberNames.Contains(nameof(PortalSessionRequest.ReturnUrl))); + } + + [Theory] + [InlineData("javascript:alert('xss')")] + [InlineData("data:text/html,")] + [InlineData("file:///etc/passwd")] + [InlineData("ftp://example.com/file")] + public void Validate_NonHttpScheme_ReturnsSchemeError(string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = returnUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.False(isValid); + Assert.Contains(results, r => r.ErrorMessage!.Contains("HTTP or HTTPS")); + } + + [Theory] + [InlineData("not-a-url")] + [InlineData("://invalid")] + [InlineData("http://")] + public void Validate_InvalidUrl_ReturnsUrlFormatError(string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = returnUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.False(isValid); + Assert.NotEmpty(results); + } + + [Fact] + public void Validate_ExcessivelyLongUrl_ReturnsMaxLengthError() + { + // Arrange + var longUrl = "https://example.com/" + new string('a', 2500); + var request = new PortalSessionRequest { ReturnUrl = longUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.False(isValid); + Assert.Contains(results, r => r.MemberNames.Contains(nameof(PortalSessionRequest.ReturnUrl))); + } + + [Theory] + [InlineData(" ")] + [InlineData("\t\n")] + public void Validate_WhitespaceOnlyUrl_ReturnsRequiredError(string returnUrl) + { + // Arrange + var request = new PortalSessionRequest { ReturnUrl = returnUrl }; + var context = new ValidationContext(request); + var results = new List(); + + // Act + var isValid = Validator.TryValidateObject(request, context, results, true); + + // Assert + Assert.False(isValid); + Assert.NotEmpty(results); + } +} diff --git a/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs new file mode 100644 index 000000000000..4dd1f8e51ace --- /dev/null +++ b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs @@ -0,0 +1,415 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Portal.Commands; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Stripe; +using Stripe.BillingPortal; +using Xunit; + +namespace Bit.Core.Test.Billing.Portal.Commands; + +using static StripeConstants; + +public class CreateBillingPortalSessionCommandTests +{ + private readonly ILogger _logger = Substitute.For>(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly CreateBillingPortalSessionCommand _command; + private readonly User _user; + + public CreateBillingPortalSessionCommandTests() + { + _command = new CreateBillingPortalSessionCommand(_logger, _stripeAdapter); + _user = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com", + GatewayCustomerId = "cus_test123", + GatewaySubscriptionId = "sub_test123" + }; + } + + [Fact] + public async Task Run_WithValidUser_ReturnsPortalUrl() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var expectedUrl = "https://billing.stripe.com/session/test123"; + var session = new Session { Url = expectedUrl }; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns(session); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT0); + Assert.Equal(expectedUrl, result.AsT0); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync( + Arg.Is(o => + o.Customer == _user.GatewayCustomerId && + o.ReturnUrl == returnUrl)); + + _logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Successfully created billing portal session") && o.ToString()!.Contains(_user.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } + + [Fact] + public async Task Run_WithoutGatewayCustomerId_ReturnsBadRequest() + { + // Arrange + var userWithoutCustomerId = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com", + GatewayCustomerId = null + }; + var returnUrl = "https://example.com/billing"; + + // Act + var result = await _command.Run(userWithoutCustomerId, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User does not have a Stripe customer ID.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("does not have a Stripe customer ID") && o.ToString()!.Contains(userWithoutCustomerId.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } + + [Fact] + public async Task Run_WithEmptyGatewayCustomerId_ReturnsBadRequest() + { + // Arrange + var userWithEmptyCustomerId = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com", + GatewayCustomerId = string.Empty + }; + var returnUrl = "https://example.com/billing"; + + // Act + var result = await _command.Run(userWithEmptyCustomerId, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User does not have a Stripe customer ID.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("does not have a Stripe customer ID") && o.ToString()!.Contains(userWithEmptyCustomerId.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } + + [Fact] + public async Task Run_WhenSessionIsNull_ReturnsConflict() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns((Session?)null); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync(Arg.Any()); + } + + [Fact] + public async Task Run_WhenSessionUrlIsNull_ReturnsConflict() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + var session = new Session { Url = null }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns(session); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync(Arg.Any()); + } + + [Fact] + public async Task Run_WhenStripeThrowsException_ReturnsUnhandled() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + var stripeException = new StripeException { StripeError = new StripeError { Code = "api_error" } }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Throws(stripeException); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.Equal(stripeException, unhandled.Exception); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync(Arg.Any()); + } + + [Fact] + public async Task Run_WithDifferentReturnUrls_UsesCorrectUrl() + { + // Arrange + var returnUrl1 = "https://example.com/billing"; + var returnUrl2 = "https://different.com/account"; + var session = new Session { Url = "https://billing.stripe.com/session/test123" }; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns(session); + + // Act + var result1 = await _command.Run(_user, returnUrl1); + var result2 = await _command.Run(_user, returnUrl2); + + // Assert + Assert.True(result1.IsT0); + Assert.True(result2.IsT0); + + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync( + Arg.Is(o => o.ReturnUrl == returnUrl1)); + await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync( + Arg.Is(o => o.ReturnUrl == returnUrl2)); + } + + [Fact] + public async Task Run_WithoutGatewaySubscriptionId_ReturnsBadRequest() + { + // Arrange + var userWithoutSubscriptionId = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com", + GatewayCustomerId = "cus_test123", + GatewaySubscriptionId = null + }; + var returnUrl = "https://example.com/billing"; + + // Act + var result = await _command.Run(userWithoutSubscriptionId, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User does not have a Premium subscription.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("does not have a subscription") && o.ToString()!.Contains(userWithoutSubscriptionId.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } + + [Fact] + public async Task Run_WithActiveSubscription_ReturnsPortalUrl() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var expectedUrl = "https://billing.stripe.com/session/test123"; + var session = new Session { Url = expectedUrl }; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns(session); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT0); + Assert.Equal(expectedUrl, result.AsT0); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + } + + [Fact] + public async Task Run_WithPastDueSubscription_ReturnsPortalUrl() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var expectedUrl = "https://billing.stripe.com/session/test456"; + var session = new Session { Url = expectedUrl }; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.PastDue }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) + .Returns(session); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT0); + Assert.Equal(expectedUrl, result.AsT0); + + await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); + } + + [Fact] + public async Task Run_WithCanceledSubscription_ReturnsBadRequest() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Canceled }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Your subscription cannot be managed in its current status.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("not eligible for portal access") && o.ToString()!.Contains(_user.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } + + [Fact] + public async Task Run_WithIncompleteSubscription_ReturnsBadRequest() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Incomplete }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Your subscription cannot be managed in its current status.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + } + + [Fact] + public async Task Run_WhenSubscriptionFetchFails_ReturnsBadRequest() + { + // Arrange + var returnUrl = "https://example.com/billing"; + var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Throws(stripeException); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Unable to verify subscription status.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to fetch subscription") && o.ToString()!.Contains(_user.Id.ToString())), + stripeException, + Arg.Any>()); + } + + [Fact] + public async Task Run_WhenSubscriptionIsNull_ReturnsBadRequest() + { + // Arrange + var returnUrl = "https://example.com/billing"; + + _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) + .Returns((Subscription?)null); + + // Act + var result = await _command.Run(_user, returnUrl); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User subscription not found.", badRequest.Response); + + await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); + + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("was not found") && o.ToString()!.Contains(_user.Id.ToString())), + Arg.Any(), + Arg.Any>()); + } +} From 618d79f7946b7ea5a7de58efea96caa082f559fc Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Mon, 16 Mar 2026 12:36:51 +0100 Subject: [PATCH 2/8] Remove comment --- src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs index 8ec1a47dc375..97deb9957b61 100644 --- a/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs +++ b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs @@ -7,7 +7,6 @@ public class PortalSessionResponse { /// /// The URL to redirect the user to for accessing the Stripe billing portal. - /// This URL is time-limited and single-use. /// public required string Url { get; init; } } From e724b2bfbd079904289a4af76ccd3315d1ddb05d Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Mon, 16 Mar 2026 12:55:13 +0100 Subject: [PATCH 3/8] formatting issues have been resolved --- src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs | 2 +- .../Billing/Models/Responses/Portal/PortalSessionResponse.cs | 2 +- .../Portal/Commands/CreateBillingPortalSessionCommand.cs | 2 +- .../Billing/Models/Requests/Portal/PortalSessionRequestTests.cs | 2 +- .../Portal/Commands/CreateBillingPortalSessionCommandTests.cs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs index e65ce5c03bda..1ea61747b442 100644 --- a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs +++ b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs @@ -1,4 +1,4 @@ -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; namespace Bit.Api.Billing.Models.Requests.Portal; diff --git a/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs index 97deb9957b61..752a295a4349 100644 --- a/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs +++ b/src/Api/Billing/Models/Responses/Portal/PortalSessionResponse.cs @@ -1,4 +1,4 @@ -namespace Bit.Api.Billing.Models.Responses.Portal; +namespace Bit.Api.Billing.Models.Responses.Portal; /// /// Response model containing the Stripe billing portal session URL. diff --git a/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs index 90830708527d..96158b9734b6 100644 --- a/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs +++ b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs @@ -1,4 +1,4 @@ -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Services; using Bit.Core.Entities; diff --git a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs index 29f772268af5..39518378a3bc 100644 --- a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs +++ b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs @@ -1,5 +1,5 @@ +using System.ComponentModel.DataAnnotations; using Bit.Api.Billing.Models.Requests.Portal; -using System.ComponentModel.DataAnnotations; using Xunit; namespace Bit.Api.Test.Billing.Models.Requests.Portal; diff --git a/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs index 4dd1f8e51ace..1bb2368a12fb 100644 --- a/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs +++ b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs @@ -1,4 +1,4 @@ -using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Portal.Commands; using Bit.Core.Billing.Services; using Bit.Core.Entities; From 45dc27328109d43e0a14de9d2958ce1e769f2c5c Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Mon, 16 Mar 2026 19:17:10 +0100 Subject: [PATCH 4/8] Allow deep linking url --- .../Models/Requests/Portal/PortalSessionRequest.cs | 8 ++++---- .../Models/Requests/Portal/PortalSessionRequestTests.cs | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs index 1ea61747b442..0d8b6afe4751 100644 --- a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs +++ b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs @@ -9,7 +9,7 @@ public class PortalSessionRequest : IValidatableObject { /// /// The URL to redirect to after the user completes their session in the billing portal. - /// Must be a valid HTTP(S) URL. + /// Must be a valid HTTP(S) or bitwarden:// URL. /// [Required] [MaxLength(2000)] @@ -27,11 +27,11 @@ public IEnumerable Validate(ValidationContext validationContex yield break; } - // Prevent open redirect vulnerabilities by restricting to HTTP(S) schemes - if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) + // Prevent open redirect vulnerabilities by restricting to HTTP(S) and bitwarden:// schemes + if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps && uri.Scheme != "bitwarden") { yield return new ValidationResult( - "Return URL must use HTTP or HTTPS scheme.", + "Return URL must use HTTP, HTTPS, or bitwarden:// scheme.", [nameof(ReturnUrl)]); } } diff --git a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs index 39518378a3bc..971e20e33e22 100644 --- a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs +++ b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs @@ -10,6 +10,8 @@ public class PortalSessionRequestTests [InlineData("https://example.com/return")] [InlineData("http://localhost:3000/billing")] [InlineData("https://app.bitwarden.com/settings/billing")] + [InlineData("bitwarden://foo")] + [InlineData("bitwarden://vault/item/12345")] public void Validate_ValidHttpsUrl_ReturnsNoErrors(string returnUrl) { // Arrange @@ -60,7 +62,7 @@ public void Validate_NonHttpScheme_ReturnsSchemeError(string returnUrl) // Assert Assert.False(isValid); - Assert.Contains(results, r => r.ErrorMessage!.Contains("HTTP or HTTPS")); + Assert.Contains(results, r => r.ErrorMessage!.Contains("HTTP, HTTPS, or bitwarden://")); } [Theory] From b9c9bed2f15ceb3ca935dc73ab47b99c7ca58800 Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Tue, 17 Mar 2026 15:59:39 +0100 Subject: [PATCH 5/8] remove thr return url request --- .../VNext/AccountBillingVNextController.cs | 23 ++-- .../Requests/Portal/PortalSessionRequest.cs | 39 ------ .../AccountBillingVNextControllerTests.cs | 106 ++++++++------- .../Portal/PortalSessionRequestTests.cs | 121 ------------------ 4 files changed, 78 insertions(+), 211 deletions(-) delete mode 100644 src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs delete mode 100644 test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index 17134cb171bf..40bf7697a22e 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -1,6 +1,5 @@ using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; -using Bit.Api.Billing.Models.Requests.Portal; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Api.Billing.Models.Requests.Storage; using Bit.Api.Billing.Models.Responses.Portal; @@ -12,7 +11,10 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Subscriptions.Queries; +using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -27,15 +29,17 @@ public class AccountBillingVNextController( ICreateBillingPortalSessionCommand createBillingPortalSessionCommand, ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, ICreatePremiumCloudHostedSubscriptionCommand createPremiumCloudHostedSubscriptionCommand, + ICurrentContext currentContext, + IGetApplicableDiscountsQuery getApplicableDiscountsQuery, IGetBitwardenSubscriptionQuery getBitwardenSubscriptionQuery, IGetCreditQuery getCreditQuery, IGetPaymentMethodQuery getPaymentMethodQuery, IGetUserLicenseQuery getUserLicenseQuery, + GlobalSettings globalSettings, IReinstateSubscriptionCommand reinstateSubscriptionCommand, IUpdatePaymentMethodCommand updatePaymentMethodCommand, IUpdatePremiumStorageCommand updatePremiumStorageCommand, - IUpgradePremiumToOrganizationCommand upgradePremiumToOrganizationCommand, - IGetApplicableDiscountsQuery getApplicableDiscountsQuery) : BaseBillingController + IUpgradePremiumToOrganizationCommand upgradePremiumToOrganizationCommand) : BaseBillingController { [HttpGet("credit")] [InjectUser] @@ -154,17 +158,20 @@ public async Task GetApplicableDiscountsAsync( /// /// Creates a Stripe billing portal session for the authenticated user. /// The portal allows users to manage their subscription, payment methods, and billing history. + /// The return URL is automatically determined based on the client type (mobile, web, desktop, etc.). /// /// The authenticated user - /// Portal session configuration including return URL /// Portal session URL for redirection [HttpPost("portal-session")] [InjectUser] - public async Task CreatePortalSessionAsync( - [BindNever] User user, - [FromBody] PortalSessionRequest request) + public async Task CreatePortalSessionAsync([BindNever] User user) { - var result = await createBillingPortalSessionCommand.Run(user, request.ReturnUrl!); + // Mobile clients use deep link callbacks, all others redirect to web vault + var returnUrl = DeviceTypes.ToClientType(currentContext.DeviceType) == ClientType.Mobile + ? "bitwarden://premium-upgrade-callback" + : $"{globalSettings.BaseServiceUri.Vault}/#/settings/subscription/premium"; + + var result = await createBillingPortalSessionCommand.Run(user, returnUrl); return Handle(result.Map(url => new PortalSessionResponse { Url = url })); } diff --git a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs b/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs deleted file mode 100644 index 0d8b6afe4751..000000000000 --- a/src/Api/Billing/Models/Requests/Portal/PortalSessionRequest.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Billing.Models.Requests.Portal; - -/// -/// Request model for creating a Stripe billing portal session. -/// -public class PortalSessionRequest : IValidatableObject -{ - /// - /// The URL to redirect to after the user completes their session in the billing portal. - /// Must be a valid HTTP(S) or bitwarden:// URL. - /// - [Required] - [MaxLength(2000)] - public string? ReturnUrl { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!string.IsNullOrWhiteSpace(ReturnUrl)) - { - if (!Uri.TryCreate(ReturnUrl, UriKind.Absolute, out var uri)) - { - yield return new ValidationResult( - "Return URL must be a valid absolute URL.", - [nameof(ReturnUrl)]); - yield break; - } - - // Prevent open redirect vulnerabilities by restricting to HTTP(S) and bitwarden:// schemes - if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps && uri.Scheme != "bitwarden") - { - yield return new ValidationResult( - "Return URL must use HTTP, HTTPS, or bitwarden:// scheme.", - [nameof(ReturnUrl)]); - } - } - } -} diff --git a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs index 67c01ca3a3da..f4275afcfe4c 100644 --- a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs @@ -1,5 +1,4 @@ using Bit.Api.Billing.Controllers.VNext; -using Bit.Api.Billing.Models.Requests.Portal; using Bit.Api.Billing.Models.Requests.Storage; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Licenses.Queries; @@ -10,7 +9,10 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Subscriptions.Queries; +using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Settings; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; @@ -30,6 +32,8 @@ public class AccountBillingVNextControllerTests private readonly IUpgradePremiumToOrganizationCommand _upgradePremiumToOrganizationCommand; private readonly IGetApplicableDiscountsQuery _getApplicableDiscountsQuery; private readonly ICreateBillingPortalSessionCommand _createBillingPortalSessionCommand; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; private readonly AccountBillingVNextController _sut; public AccountBillingVNextControllerTests() @@ -39,20 +43,28 @@ public AccountBillingVNextControllerTests() _upgradePremiumToOrganizationCommand = Substitute.For(); _getApplicableDiscountsQuery = Substitute.For(); _createBillingPortalSessionCommand = Substitute.For(); + _currentContext = Substitute.For(); + _globalSettings = new GlobalSettings + { + BaseServiceUri = new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) + }; + _globalSettings.BaseServiceUri.Vault = "https://vault.bitwarden.com"; _sut = new AccountBillingVNextController( _createBillingPortalSessionCommand, Substitute.For(), Substitute.For(), + _currentContext, + _getApplicableDiscountsQuery, Substitute.For(), Substitute.For(), Substitute.For(), _getUserLicenseQuery, + _globalSettings, Substitute.For(), Substitute.For(), _updatePremiumStorageCommand, - _upgradePremiumToOrganizationCommand, - _getApplicableDiscountsQuery); + _upgradePremiumToOrganizationCommand); } [Theory, BitAutoData] @@ -300,140 +312,148 @@ public async Task GetApplicableDiscountsAsync_EligibleDiscounts_ReturnsOkWithDis } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_Success_ReturnsPortalUrlAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_Success_ReturnsPortalUrlAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; var portalUrl = "https://billing.stripe.com/session/test123"; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.ChromeBrowser); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(portalUrl)); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsBadRequestAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsBadRequestAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.Android); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new BadRequest("User does not have a Stripe customer ID."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsBadRequestAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsBadRequestAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.iOS); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new BadRequest("User does not have a Premium subscription."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_InvalidSubscriptionStatus_ReturnsBadRequestAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_InvalidSubscriptionStatus_ReturnsBadRequestAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns((DeviceType?)null); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new BadRequest("Your subscription cannot be managed in its current status."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsBadRequestAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsBadRequestAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.WindowsDesktop); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new BadRequest("User subscription not found."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_StripeException_ReturnsServerErrorAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_StripeException_ReturnsServerErrorAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; var exception = new StripeException("Stripe API error"); - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.MacOsDesktop); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Unhandled(exception))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_SessionWithNullUrl_ReturnsServerErrorAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_SessionWithNullUrl_ReturnsServerErrorAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.ChromeExtension); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_NullSession_ReturnsServerErrorAsync(User user, string returnUrl) + public async Task CreatePortalSessionAsync_NullSession_ReturnsServerErrorAsync(User user) { // Arrange - var request = new PortalSessionRequest { ReturnUrl = $"https://example.com/{returnUrl}" }; + var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; - _createBillingPortalSessionCommand.Run(user, request.ReturnUrl) + _currentContext.DeviceType.Returns(DeviceType.LinuxDesktop); + _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); // Act - var result = await _sut.CreatePortalSessionAsync(user, request); + var result = await _sut.CreatePortalSessionAsync(user); // Assert Assert.IsAssignableFrom(result); - await _createBillingPortalSessionCommand.Received(1).Run(user, request.ReturnUrl); + await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } } diff --git a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs b/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs deleted file mode 100644 index 971e20e33e22..000000000000 --- a/test/Api.Test/Billing/Models/Requests/Portal/PortalSessionRequestTests.cs +++ /dev/null @@ -1,121 +0,0 @@ -using System.ComponentModel.DataAnnotations; -using Bit.Api.Billing.Models.Requests.Portal; -using Xunit; - -namespace Bit.Api.Test.Billing.Models.Requests.Portal; - -public class PortalSessionRequestTests -{ - [Theory] - [InlineData("https://example.com/return")] - [InlineData("http://localhost:3000/billing")] - [InlineData("https://app.bitwarden.com/settings/billing")] - [InlineData("bitwarden://foo")] - [InlineData("bitwarden://vault/item/12345")] - public void Validate_ValidHttpsUrl_ReturnsNoErrors(string returnUrl) - { - // Arrange - var request = new PortalSessionRequest { ReturnUrl = returnUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.True(isValid); - Assert.Empty(results); - } - - [Theory] - [InlineData(null)] - [InlineData("")] - public void Validate_MissingReturnUrl_ReturnsRequiredError(string? returnUrl) - { - // Arrange - var request = new PortalSessionRequest { ReturnUrl = returnUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.False(isValid); - Assert.Contains(results, r => r.MemberNames.Contains(nameof(PortalSessionRequest.ReturnUrl))); - } - - [Theory] - [InlineData("javascript:alert('xss')")] - [InlineData("data:text/html,")] - [InlineData("file:///etc/passwd")] - [InlineData("ftp://example.com/file")] - public void Validate_NonHttpScheme_ReturnsSchemeError(string returnUrl) - { - // Arrange - var request = new PortalSessionRequest { ReturnUrl = returnUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.False(isValid); - Assert.Contains(results, r => r.ErrorMessage!.Contains("HTTP, HTTPS, or bitwarden://")); - } - - [Theory] - [InlineData("not-a-url")] - [InlineData("://invalid")] - [InlineData("http://")] - public void Validate_InvalidUrl_ReturnsUrlFormatError(string returnUrl) - { - // Arrange - var request = new PortalSessionRequest { ReturnUrl = returnUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.False(isValid); - Assert.NotEmpty(results); - } - - [Fact] - public void Validate_ExcessivelyLongUrl_ReturnsMaxLengthError() - { - // Arrange - var longUrl = "https://example.com/" + new string('a', 2500); - var request = new PortalSessionRequest { ReturnUrl = longUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.False(isValid); - Assert.Contains(results, r => r.MemberNames.Contains(nameof(PortalSessionRequest.ReturnUrl))); - } - - [Theory] - [InlineData(" ")] - [InlineData("\t\n")] - public void Validate_WhitespaceOnlyUrl_ReturnsRequiredError(string returnUrl) - { - // Arrange - var request = new PortalSessionRequest { ReturnUrl = returnUrl }; - var context = new ValidationContext(request); - var results = new List(); - - // Act - var isValid = Validator.TryValidateObject(request, context, results, true); - - // Assert - Assert.False(isValid); - Assert.NotEmpty(results); - } -} From 5e7bf7893f3e32b42f6e5427afd65873be174a0b Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Wed, 18 Mar 2026 17:06:05 +0100 Subject: [PATCH 6/8] Resolve review comments around comments --- .../VNext/AccountBillingVNextController.cs | 19 ++++++----------- .../CreateBillingPortalSessionCommand.cs | 21 +++---------------- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index 40bf7697a22e..f357cb13d3c9 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -14,7 +14,6 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -35,7 +34,6 @@ public class AccountBillingVNextController( IGetCreditQuery getCreditQuery, IGetPaymentMethodQuery getPaymentMethodQuery, IGetUserLicenseQuery getUserLicenseQuery, - GlobalSettings globalSettings, IReinstateSubscriptionCommand reinstateSubscriptionCommand, IUpdatePaymentMethodCommand updatePaymentMethodCommand, IUpdatePremiumStorageCommand updatePremiumStorageCommand, @@ -155,21 +153,16 @@ public async Task GetApplicableDiscountsAsync( return Handle(result); } - /// - /// Creates a Stripe billing portal session for the authenticated user. - /// The portal allows users to manage their subscription, payment methods, and billing history. - /// The return URL is automatically determined based on the client type (mobile, web, desktop, etc.). - /// - /// The authenticated user - /// Portal session URL for redirection [HttpPost("portal-session")] [InjectUser] public async Task CreatePortalSessionAsync([BindNever] User user) { - // Mobile clients use deep link callbacks, all others redirect to web vault - var returnUrl = DeviceTypes.ToClientType(currentContext.DeviceType) == ClientType.Mobile - ? "bitwarden://premium-upgrade-callback" - : $"{globalSettings.BaseServiceUri.Vault}/#/settings/subscription/premium"; + if (DeviceTypes.ToClientType(currentContext.DeviceType) != ClientType.Mobile) + { + return TypedResults.NotFound(); + } + + var returnUrl = "bitwarden://premium-upgrade-callback"; var result = await createBillingPortalSessionCommand.Run(user, returnUrl); return Handle(result.Map(url => new PortalSessionResponse { Url = url })); diff --git a/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs index 96158b9734b6..5a6b46a6dd27 100644 --- a/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs +++ b/src/Core/Billing/Portal/Commands/CreateBillingPortalSessionCommand.cs @@ -32,14 +32,14 @@ public Task> Run(User user, string returnUrl) => { _logger.LogWarning("{Command}: User ({UserId}) does not have a Stripe customer ID", CommandName, user.Id); - return new BadRequest("User does not have a Stripe customer ID."); + return DefaultConflict; } if (string.IsNullOrEmpty(user.GatewaySubscriptionId)) { _logger.LogWarning("{Command}: User ({UserId}) does not have a subscription", CommandName, user.Id); - return new BadRequest("User does not have a Premium subscription."); + return DefaultConflict; } // Fetch the subscription to validate its status @@ -53,14 +53,7 @@ public Task> Run(User user, string returnUrl) => _logger.LogError(stripeException, "{Command}: Failed to fetch subscription ({SubscriptionId}) for user ({UserId})", CommandName, user.GatewaySubscriptionId, user.Id); - return new BadRequest("Unable to verify subscription status."); - } - - if (subscription == null) - { - _logger.LogWarning("{Command}: Subscription ({SubscriptionId}) for user ({UserId}) was not found", - CommandName, user.GatewaySubscriptionId, user.Id); - return new BadRequest("User subscription not found."); + return DefaultConflict; } // Only allow portal access for active or past_due subscriptions @@ -80,14 +73,6 @@ public Task> Run(User user, string returnUrl) => var session = await stripeAdapter.CreateBillingPortalSessionAsync(options); - if (session?.Url == null) - { - return DefaultConflict; - } - - _logger.LogInformation("{Command}: Successfully created billing portal session for user ({UserId})", - CommandName, user.Id); - return session.Url; }); } From 1b76d59caaca31f18ae86e43a6d2e961c16dd517 Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Wed, 18 Mar 2026 17:17:31 +0100 Subject: [PATCH 7/8] Fix the failing test after removing _globalSettings --- .../AccountBillingVNextControllerTests.cs | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs index f4275afcfe4c..92ba431d600f 100644 --- a/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/VNext/AccountBillingVNextControllerTests.cs @@ -12,7 +12,6 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Settings; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; @@ -22,6 +21,7 @@ using Xunit; using BadRequest = Bit.Core.Billing.Commands.BadRequest; using Conflict = Bit.Core.Billing.Commands.Conflict; +using NotFound = Microsoft.AspNetCore.Http.HttpResults.NotFound; namespace Bit.Api.Test.Billing.Controllers.VNext; @@ -33,7 +33,6 @@ public class AccountBillingVNextControllerTests private readonly IGetApplicableDiscountsQuery _getApplicableDiscountsQuery; private readonly ICreateBillingPortalSessionCommand _createBillingPortalSessionCommand; private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; private readonly AccountBillingVNextController _sut; public AccountBillingVNextControllerTests() @@ -44,11 +43,6 @@ public AccountBillingVNextControllerTests() _getApplicableDiscountsQuery = Substitute.For(); _createBillingPortalSessionCommand = Substitute.For(); _currentContext = Substitute.For(); - _globalSettings = new GlobalSettings - { - BaseServiceUri = new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) - }; - _globalSettings.BaseServiceUri.Vault = "https://vault.bitwarden.com"; _sut = new AccountBillingVNextController( _createBillingPortalSessionCommand, @@ -60,7 +54,6 @@ public AccountBillingVNextControllerTests() Substitute.For(), Substitute.For(), _getUserLicenseQuery, - _globalSettings, Substitute.For(), Substitute.For(), _updatePremiumStorageCommand, @@ -316,9 +309,9 @@ public async Task CreatePortalSessionAsync_Success_ReturnsPortalUrlAsync(User us { // Arrange var portalUrl = "https://billing.stripe.com/session/test123"; - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns(DeviceType.ChromeBrowser); + _currentContext.DeviceType.Returns(DeviceType.Android); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(portalUrl)); @@ -331,14 +324,14 @@ public async Task CreatePortalSessionAsync_Success_ReturnsPortalUrlAsync(User us } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsBadRequestAsync(User user) + public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsConflictAsync(User user) { // Arrange var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns(DeviceType.Android); + _currentContext.DeviceType.Returns(DeviceType.AndroidAmazon); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) - .Returns(new BillingCommandResult(new BadRequest("User does not have a Stripe customer ID."))); + .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); // Act var result = await _sut.CreatePortalSessionAsync(user); @@ -349,14 +342,14 @@ public async Task CreatePortalSessionAsync_NoCustomerId_ReturnsBadRequestAsync(U } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsBadRequestAsync(User user) + public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsConflictAsync(User user) { // Arrange var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; _currentContext.DeviceType.Returns(DeviceType.iOS); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) - .Returns(new BillingCommandResult(new BadRequest("User does not have a Premium subscription."))); + .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); // Act var result = await _sut.CreatePortalSessionAsync(user); @@ -370,9 +363,9 @@ public async Task CreatePortalSessionAsync_NoSubscriptionId_ReturnsBadRequestAsy public async Task CreatePortalSessionAsync_InvalidSubscriptionStatus_ReturnsBadRequestAsync(User user) { // Arrange - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns((DeviceType?)null); + _currentContext.DeviceType.Returns(DeviceType.iOS); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new BadRequest("Your subscription cannot be managed in its current status."))); @@ -385,14 +378,14 @@ public async Task CreatePortalSessionAsync_InvalidSubscriptionStatus_ReturnsBadR } [Theory, BitAutoData] - public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsBadRequestAsync(User user) + public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsConflictAsync(User user) { // Arrange - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns(DeviceType.WindowsDesktop); + _currentContext.DeviceType.Returns(DeviceType.Android); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) - .Returns(new BillingCommandResult(new BadRequest("User subscription not found."))); + .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); // Act var result = await _sut.CreatePortalSessionAsync(user); @@ -406,10 +399,10 @@ public async Task CreatePortalSessionAsync_SubscriptionNotFound_ReturnsBadReques public async Task CreatePortalSessionAsync_StripeException_ReturnsServerErrorAsync(User user) { // Arrange - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; var exception = new StripeException("Stripe API error"); - _currentContext.DeviceType.Returns(DeviceType.MacOsDesktop); + _currentContext.DeviceType.Returns(DeviceType.iOS); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Unhandled(exception))); @@ -425,9 +418,9 @@ public async Task CreatePortalSessionAsync_StripeException_ReturnsServerErrorAsy public async Task CreatePortalSessionAsync_SessionWithNullUrl_ReturnsServerErrorAsync(User user) { // Arrange - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns(DeviceType.ChromeExtension); + _currentContext.DeviceType.Returns(DeviceType.Android); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); @@ -443,9 +436,9 @@ public async Task CreatePortalSessionAsync_SessionWithNullUrl_ReturnsServerError public async Task CreatePortalSessionAsync_NullSession_ReturnsServerErrorAsync(User user) { // Arrange - var expectedReturnUrl = "https://vault.bitwarden.com/#/settings/subscription/premium"; + var expectedReturnUrl = "bitwarden://premium-upgrade-callback"; - _currentContext.DeviceType.Returns(DeviceType.LinuxDesktop); + _currentContext.DeviceType.Returns(DeviceType.iOS); _createBillingPortalSessionCommand.Run(user, expectedReturnUrl) .Returns(new BillingCommandResult(new Conflict("Unable to create billing portal session. Please contact support for assistance."))); @@ -456,4 +449,18 @@ public async Task CreatePortalSessionAsync_NullSession_ReturnsServerErrorAsync(U Assert.IsAssignableFrom(result); await _createBillingPortalSessionCommand.Received(1).Run(user, expectedReturnUrl); } + + [Theory, BitAutoData] + public async Task CreatePortalSessionAsync_NonMobileDevice_ReturnsNotFoundAsync(User user) + { + // Arrange + _currentContext.DeviceType.Returns(DeviceType.ChromeBrowser); + + // Act + var result = await _sut.CreatePortalSessionAsync(user); + + // Assert + Assert.IsType(result); + await _createBillingPortalSessionCommand.DidNotReceiveWithAnyArgs().Run(Arg.Any(), Arg.Any()); + } } From 81fbff393e715bc44e4302d225976d027f497d9d Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Wed, 18 Mar 2026 17:22:47 +0100 Subject: [PATCH 8/8] Fix the failing unit test --- .../CreateBillingPortalSessionCommandTests.cs | 114 +++--------------- 1 file changed, 16 insertions(+), 98 deletions(-) diff --git a/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs index 1bb2368a12fb..b16dc0e07865 100644 --- a/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs +++ b/test/Core.Test/Billing/Portal/Commands/CreateBillingPortalSessionCommandTests.cs @@ -58,17 +58,10 @@ await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync( Arg.Is(o => o.Customer == _user.GatewayCustomerId && o.ReturnUrl == returnUrl)); - - _logger.Received(1).Log( - LogLevel.Information, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("Successfully created billing portal session") && o.ToString()!.Contains(_user.Id.ToString())), - Arg.Any(), - Arg.Any>()); } [Fact] - public async Task Run_WithoutGatewayCustomerId_ReturnsBadRequest() + public async Task Run_WithoutGatewayCustomerId_ReturnsConflict() { // Arrange var userWithoutCustomerId = new User @@ -83,9 +76,9 @@ public async Task Run_WithoutGatewayCustomerId_ReturnsBadRequest() var result = await _command.Run(userWithoutCustomerId, returnUrl); // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal("User does not have a Stripe customer ID.", badRequest.Response); + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); @@ -98,7 +91,7 @@ public async Task Run_WithoutGatewayCustomerId_ReturnsBadRequest() } [Fact] - public async Task Run_WithEmptyGatewayCustomerId_ReturnsBadRequest() + public async Task Run_WithEmptyGatewayCustomerId_ReturnsConflict() { // Arrange var userWithEmptyCustomerId = new User @@ -113,9 +106,9 @@ public async Task Run_WithEmptyGatewayCustomerId_ReturnsBadRequest() var result = await _command.Run(userWithEmptyCustomerId, returnUrl); // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal("User does not have a Stripe customer ID.", badRequest.Response); + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); @@ -127,55 +120,6 @@ public async Task Run_WithEmptyGatewayCustomerId_ReturnsBadRequest() Arg.Any>()); } - [Fact] - public async Task Run_WhenSessionIsNull_ReturnsConflict() - { - // Arrange - var returnUrl = "https://example.com/billing"; - var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; - - _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) - .Returns(subscription); - _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) - .Returns((Session?)null); - - // Act - var result = await _command.Run(_user, returnUrl); - - // Assert - Assert.True(result.IsT2); - var conflict = result.AsT2; - Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); - - await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); - await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync(Arg.Any()); - } - - [Fact] - public async Task Run_WhenSessionUrlIsNull_ReturnsConflict() - { - // Arrange - var returnUrl = "https://example.com/billing"; - var subscription = new Subscription { Id = _user.GatewaySubscriptionId, Status = SubscriptionStatus.Active }; - var session = new Session { Url = null }; - - _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) - .Returns(subscription); - _stripeAdapter.CreateBillingPortalSessionAsync(Arg.Any()) - .Returns(session); - - // Act - var result = await _command.Run(_user, returnUrl); - - // Assert - Assert.True(result.IsT2); - var conflict = result.AsT2; - Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); - - await _stripeAdapter.Received(1).GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()); - await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync(Arg.Any()); - } - [Fact] public async Task Run_WhenStripeThrowsException_ReturnsUnhandled() { @@ -230,7 +174,7 @@ await _stripeAdapter.Received(1).CreateBillingPortalSessionAsync( } [Fact] - public async Task Run_WithoutGatewaySubscriptionId_ReturnsBadRequest() + public async Task Run_WithoutGatewaySubscriptionId_ReturnsConflict() { // Arrange var userWithoutSubscriptionId = new User @@ -246,9 +190,9 @@ public async Task Run_WithoutGatewaySubscriptionId_ReturnsBadRequest() var result = await _command.Run(userWithoutSubscriptionId, returnUrl); // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal("User does not have a Premium subscription.", badRequest.Response); + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); @@ -359,7 +303,7 @@ public async Task Run_WithIncompleteSubscription_ReturnsBadRequest() } [Fact] - public async Task Run_WhenSubscriptionFetchFails_ReturnsBadRequest() + public async Task Run_WhenSubscriptionFetchFails_ReturnsConflict() { // Arrange var returnUrl = "https://example.com/billing"; @@ -372,9 +316,9 @@ public async Task Run_WhenSubscriptionFetchFails_ReturnsBadRequest() var result = await _command.Run(_user, returnUrl); // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal("Unable to verify subscription status.", badRequest.Response); + Assert.True(result.IsT2); + var conflict = result.AsT2; + Assert.Equal("Unable to create billing portal session. Please contact support for assistance.", conflict.Response); await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); @@ -386,30 +330,4 @@ public async Task Run_WhenSubscriptionFetchFails_ReturnsBadRequest() Arg.Any>()); } - [Fact] - public async Task Run_WhenSubscriptionIsNull_ReturnsBadRequest() - { - // Arrange - var returnUrl = "https://example.com/billing"; - - _stripeAdapter.GetSubscriptionAsync(_user.GatewaySubscriptionId, Arg.Any()) - .Returns((Subscription?)null); - - // Act - var result = await _command.Run(_user, returnUrl); - - // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal("User subscription not found.", badRequest.Response); - - await _stripeAdapter.DidNotReceive().CreateBillingPortalSessionAsync(Arg.Any()); - - _logger.Received(1).Log( - LogLevel.Warning, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("was not found") && o.ToString()!.Contains(_user.Id.ToString())), - Arg.Any(), - Arg.Any>()); - } }