Skip to content
2 changes: 1 addition & 1 deletion src/Api/Auth/Controllers/WebAuthnController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ private async Task ValidateIfUserCanUsePasskeyLogin(Guid userId)
return;
}

var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(userId);
var requireSsoPolicyRequirement = await _policyRequirementQuery.GetAsyncVNext<RequireSsoPolicyRequirement>(userId);

if (!requireSsoPolicyRequirement.CanUsePasskeyLogin)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ public interface IPolicyRequirementQuery
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement;

/// <summary>
/// Get a policy requirement for a specific user using the optimized single-user query.
/// The policy requirement represents how one or more policy types should be enforced against the user.
/// It will always return a value even if there are no policies that should be enforced.
/// This is the vNext version that uses the optimized GetPolicyDetailsByUserIdAndPolicyTypeAsync method.
/// </summary>
/// <param name="userId">The user that you need to enforce the policy against.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsyncVNext<T>(Guid userId) where T : IPolicyRequirement;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review: PR #7173 β€” Finalize RequireSsoPolicyRequirement

Summary: This PR adds an optimized single-user stored procedure for policy detail queries, wires it through a GetAsyncVNext method, and switches the two RequireSsoPolicyRequirement consumers to use it. The approach is sound β€” keeping the feature flag as a killswitch while optimizing the hot login path. Overall this is a clean, well-structured PR.

Findings (all non-blocking):

  1. CanUsePasskeyLogin logic change (RequireSsoPolicyRequirement.cs:51-54): The old All(Revoked||Invited) β†’ new !Any(Accepted||Confirmed) change is logically equivalent since Enforce already filters out Invited/Revoked. The new code is more robust by explicitly naming the disabling statuses. Good change.

  2. Dead mock setup in BaseRequestValidatorTests:551,590: These still mock GetAsync<RequireSsoPolicyRequirement> (not GetAsyncVNext), but since _ssoRequestValidator.ValidateAsync is also mocked, the mock is never hit. Pre-existing, not introduced by this PR. Cleanup candidate for follow-up.

  3. SQL sproc looks correct: CTE with UNION ALL for non-invited/invited users, @UserEmail IS NOT NULL guard, LEFT JOIN for providers, proper Enabled/UsePolicies filters. Migration matches sproc (only CREATE vs CREATE OR ALTER).

  4. EF implementation matches sproc logic: Handles email null-safety correctly (SQL NULL = NULL β†’ false), uses HashSet for provider lookup, consistent filter conditions.

  5. Integration tests are thorough: Covers confirmed/accepted/invited users, multiple orgs, policy type filtering, disabled policies/orgs, UsePolicies=false, and provider flag.

Verdict: Looks good. βœ…


/// <summary>
/// Get a policy requirement for a list of users.
/// The policy requirement represents how one or more policy types should be enforced against the users.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ public class PolicyRequirementQuery(
public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
=> (await GetAsync<T>([userId])).Single().Requirement;

public async Task<T> GetAsyncVNext<T>(Guid userId) where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
if (factory is null)
{
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}

var policyDetails = await policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, factory.PolicyType);
var enforcedPolicyDetails = policyDetails.Where(factory.Enforce);

return factory.Create(enforcedPolicyDetails);
}

public async Task<IEnumerable<(Guid UserId, T Requirement)>> GetAsync<T>(IEnumerable<Guid> userIds) where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ public RequireSsoPolicyRequirementFactory(GlobalSettings globalSettings)

public override RequireSsoPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
policyDetails = policyDetails.ToList();
var result = new RequireSsoPolicyRequirement
{
CanUsePasskeyLogin = policyDetails.All(p =>
p.OrganizationUserStatus == OrganizationUserStatusType.Revoked ||
p.OrganizationUserStatus == OrganizationUserStatusType.Invited),
CanUsePasskeyLogin = !policyDetails.Any(p =>
p.OrganizationUserStatus is OrganizationUserStatusType.Accepted or OrganizationUserStatusType.Confirmed),

SsoRequired = policyDetails.Any(p =>
p.OrganizationUserStatus == OrganizationUserStatusType.Confirmed)
Expand Down
16 changes: 16 additions & 0 deletions src/Core/AdminConsole/Repositories/IPolicyRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,20 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
/// associated with the specified users and policy type.
/// </returns>
Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable<Guid> userIds, PolicyType policyType);

/// <summary>
/// Retrieves policy details for a single user filtered by the specified policy type.
/// </summary>
/// <remarks>
/// Returns policy details only for enabled policies from enabled organizations that support policies.
/// This includes both confirmed users (matched by UserId) and invited users (matched by email).
/// Provider users are identified via the IsProvider flag.
/// </remarks>
/// <param name="userId">The user identifier for which policy details are to be fetched.</param>
/// <param name="policyType">The type of policy for which the details are required.</param>
/// <returns>
/// An asynchronous task that returns a collection of <see cref="PolicyDetails"/> objects containing
/// the policy information associated with the specified user and policy type.
/// </returns>
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private async Task<bool> RequireSsoAuthenticationAsync(User user, string grantTy

// Check if user belongs to any organization with an active SSO policy
var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements)
? (await _policyRequirementQuery.GetAsync<RequireSsoPolicyRequirement>(user.Id))
? (await _policyRequirementQuery.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id))
.SsoRequired
: await _policyService.AnyPoliciesApplicableToUserAsync(
user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,19 @@ public async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByOrga
return results.ToList();
}
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType)
{
await using var connection = new SqlConnection(ConnectionString);
var results = await connection.QueryAsync<PolicyDetails>(
$"[{Schema}].[PolicyDetails_ReadByUserIdPolicyType]",
new
{
UserId = userId,
PolicyType = (byte)policyType
},
commandType: CommandType.StoredProcedure);

return results.ToList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,53 @@ where p.Enabled

return allResults.ToList();
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserIdAndPolicyTypeAsync(Guid userId, PolicyType policyType)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);

// Get user email for invited user matching
var userEmail = await dbContext.Users
.Where(u => u.Id == userId)
.Select(u => u.Email)
.FirstOrDefaultAsync();

// Get provider relationships
var providerOrganizationIds = await (from pu in dbContext.ProviderUsers
join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId
where pu.UserId == userId
select po.OrganizationId)
.Distinct()
.ToListAsync();

var providerSet = new HashSet<Guid>(providerOrganizationIds);

// Get organization users (both confirmed/accepted and invited)
var orgUsersQuery = dbContext.OrganizationUsers
.Where(ou => (ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) ||
(ou.Status == OrganizationUserStatusType.Invited && ou.Email == userEmail && userEmail != null));

// Join with policies and organizations
var query = from policy in dbContext.Policies
join orgUser in orgUsersQuery on policy.OrganizationId equals orgUser.OrganizationId
join org in dbContext.Organizations on policy.OrganizationId equals org.Id
where policy.Type == policyType
&& policy.Enabled
&& org.Enabled
&& org.UsePolicies
select new PolicyDetails
{
OrganizationUserId = orgUser.Id,
OrganizationId = policy.OrganizationId,
PolicyType = policy.Type,
PolicyData = policy.Data,
OrganizationUserType = orgUser.Type,
OrganizationUserStatus = orgUser.Status,
OrganizationUserPermissionsData = orgUser.Permissions,
IsProvider = providerSet.Contains(policy.OrganizationId)
};

return await query.ToListAsync();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
CREATE PROCEDURE [dbo].[PolicyDetails_ReadByUserIdPolicyType]
@UserId UNIQUEIDENTIFIER,
@PolicyType TINYINT
AS
BEGIN
SET NOCOUNT ON

DECLARE @UserEmail NVARCHAR(256)
SELECT @UserEmail = Email
FROM
[dbo].[UserView]
WHERE
Id = @UserId

;WITH OrgUsers AS
(
-- Non-invited users (Status != 0): direct UserId match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] != 0
AND OU.[UserId] = @UserId

UNION ALL

-- Invited users (Status = 0): email match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] = 0
AND OU.[Email] = @UserEmail
AND @UserEmail IS NOT NULL
),
Providers AS
(
SELECT
OrganizationId
FROM
[dbo].[UserProviderAccessView]
WHERE
UserId = @UserId
)
SELECT
OU.[Id] AS OrganizationUserId,
P.[OrganizationId],
P.[Type] AS PolicyType,
P.[Data] AS PolicyData,
OU.[Type] AS OrganizationUserType,
OU.[Status] AS OrganizationUserStatus,
OU.[Permissions] AS OrganizationUserPermissionsData,
CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS IsProvider
FROM
[dbo].[PolicyView] P
INNER JOIN
OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId]
INNER JOIN
[dbo].[OrganizationView] O ON P.[OrganizationId] = O.[Id]
LEFT JOIN
Providers PR ON PR.[OrganizationId] = OU.[OrganizationId]
WHERE
P.[Type] = @PolicyType
AND P.[Enabled] = 1
AND O.[Enabled] = 1
AND O.[UsePolicies] = 1
END
8 changes: 4 additions & 4 deletions test/Api.Test/Auth/Controllers/WebAuthnControllerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public async Task AttestationOptions_WithPolicyRequirementsEnabled_CanUsePasskey
sutProvider.GetDependency<IUserService>().VerifySecretAsync(user, default).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<RequireSsoPolicyRequirement>(user.Id)
.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id)
.ReturnsForAnyArgs(new RequireSsoPolicyRequirement { CanUsePasskeyLogin = false });

// Act & Assert
Expand All @@ -123,7 +123,7 @@ public async Task AttestationOptions_WithPolicyRequirementsEnabled_CanUsePasskey
sutProvider.GetDependency<IUserService>().VerifySecretAsync(user, default).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<RequireSsoPolicyRequirement>(user.Id)
.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id)
.ReturnsForAnyArgs(new RequireSsoPolicyRequirement { CanUsePasskeyLogin = true });
sutProvider.GetDependency<IDataProtectorTokenFactory<WebAuthnCredentialCreateOptionsTokenable>>()
.Protect(Arg.Any<WebAuthnCredentialCreateOptionsTokenable>()).Returns("token");
Expand Down Expand Up @@ -329,7 +329,7 @@ public async Task Post_WithPolicyRequirementsEnabled_CanUsePasskeyLoginFalse_Thr
.Returns(token);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<RequireSsoPolicyRequirement>(user.Id)
.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id)
.ReturnsForAnyArgs(new RequireSsoPolicyRequirement { CanUsePasskeyLogin = false });

// Act & Assert
Expand Down Expand Up @@ -359,7 +359,7 @@ public async Task Post_WithPolicyRequirementsEnabled_CanUsePasskeyLoginTrue_Succ
.Returns(token);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).ReturnsForAnyArgs(true);
sutProvider.GetDependency<IPolicyRequirementQuery>()
.GetAsync<RequireSsoPolicyRequirement>(user.Id)
.GetAsyncVNext<RequireSsoPolicyRequirement>(user.Id)
.ReturnsForAnyArgs(new RequireSsoPolicyRequirement { CanUsePasskeyLogin = true });

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,60 @@ public async Task GetAsync_HandlesNoPolicies(Guid userId)
Assert.Empty(requirement.Policies);
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_CallsEnforceCallback(Guid userId)
{
// Arrange policies
var policyRepository = Substitute.For<IPolicyRepository>();
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg)
.Returns([thisPolicy, otherPolicy]);

// Arrange a substitute Enforce function so that we can inspect the received calls
var callback = Substitute.For<Func<PolicyDetails, bool>>();
callback(Arg.Any<PolicyDetails>()).Returns(x => x.Arg<PolicyDetails>() == thisPolicy);

// Arrange the sut
var factory = new TestPolicyRequirementFactory(callback);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);

// Act
var requirement = await sut.GetAsyncVNext<TestPolicyRequirement>(userId);

// Assert
Assert.Contains(thisPolicy, requirement.Policies);
Assert.DoesNotContain(otherPolicy, requirement.Policies);
callback.Received()(Arg.Is(thisPolicy));
callback.Received()(Arg.Is(otherPolicy));
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_ThrowsIfNoFactoryRegistered(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var sut = new PolicyRequirementQuery(policyRepository, []);

var exception = await Assert.ThrowsAsync<NotImplementedException>(()
=> sut.GetAsyncVNext<TestPolicyRequirement>(userId));
Assert.Contains("No Requirement Factory found", exception.Message);
}

[Theory, BitAutoData]
public async Task GetAsyncVNext_HandlesNoPolicies(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserIdAndPolicyTypeAsync(userId, PolicyType.SingleOrg)
.Returns([]);

var factory = new TestPolicyRequirementFactory(x => x.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);

var requirement = await sut.GetAsyncVNext<TestPolicyRequirement>(userId);

Assert.Empty(requirement.Policies);
}

[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB)
{
Expand Down
Loading
Loading