Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,57 +34,50 @@ public static async Task<T> ExecuteIdempotentAsync<T>(
CommandMetadata? metadata,
CancellationToken cancellationToken = default)
{
// Fast path: check for existing completed result
var existingResult = await store.GetCommandResultAsync<T>(commandId, cancellationToken);
if (existingResult != null)
while (true)
{
return existingResult;
}
cancellationToken.ThrowIfCancellationRequested();

// Atomically try to claim the command for execution
var (currentStatus, wasSet) = await store.GetAndSetStatusAsync(
commandId,
CommandExecutionStatus.InProgress,
cancellationToken);
var currentStatus = await store.GetCommandStatusAsync(commandId, cancellationToken);

switch (currentStatus)
{
case CommandExecutionStatus.Completed:
// Result exists but might have been evicted, get it again
existingResult = await store.GetCommandResultAsync<T>(commandId, cancellationToken);
return existingResult ?? throw new InvalidOperationException($"Command {commandId} marked as completed but result not found");

case CommandExecutionStatus.InProgress:
case CommandExecutionStatus.Processing:
// Another thread is executing, wait for completion
return await WaitForCompletionAsync<T>(store, commandId, cancellationToken);

case CommandExecutionStatus.Failed:
// Previous execution failed, we can retry (wasSet should be true)
if (!wasSet)
switch (currentStatus)
{
case CommandExecutionStatus.Completed:
{
// Race condition - another thread claimed it
return await WaitForCompletionAsync<T>(store, commandId, cancellationToken);
var cachedResult = await store.GetCommandResultAsync<T>(commandId, cancellationToken);
return cachedResult ?? default!;
}
break;

case CommandExecutionStatus.NotFound:
case CommandExecutionStatus.NotStarted:
default:
// First execution (wasSet should be true)
if (!wasSet)
{
// Race condition - another thread claimed it

case CommandExecutionStatus.InProgress:
case CommandExecutionStatus.Processing:
return await WaitForCompletionAsync<T>(store, commandId, cancellationToken);

case CommandExecutionStatus.NotFound:
case CommandExecutionStatus.NotStarted:
case CommandExecutionStatus.Failed:
default:
{
var claimed = await store.TrySetCommandStatusAsync(
commandId,
currentStatus,
CommandExecutionStatus.InProgress,
cancellationToken);

if (claimed)
{
goto ExecuteOperation;
}

break;
}
break;
}
}

// We successfully claimed the command for execution
ExecuteOperation:
try
{
var result = await operation();

// Store result and mark as completed atomically
await store.SetCommandResultAsync(commandId, result, cancellationToken);
await store.SetCommandStatusAsync(commandId, CommandExecutionStatus.Completed, cancellationToken);
Expand Down Expand Up @@ -159,7 +152,7 @@ public static async Task<T> ExecuteIdempotentWithRetryAsync<T>(
if (status == CommandExecutionStatus.Completed)
{
var result = await store.GetCommandResultAsync<T>(commandId, cancellationToken);
return (result != null, result);
return (true, result);
}

return (false, default);
Expand Down Expand Up @@ -192,17 +185,18 @@ public static async Task<Dictionary<string, T>> ExecuteBatchIdempotentAsync<T>(
var operationsList = operations.ToList();
var commandIds = operationsList.Select(op => op.commandId).ToList();

// Check for existing results in batch
var existingStatuses = await store.GetMultipleStatusAsync(commandIds, cancellationToken);
var existingResults = await store.GetMultipleResultsAsync<T>(commandIds, cancellationToken);
var results = new Dictionary<string, T>();
var pendingOperations = new List<(string commandId, Func<Task<T>> operation)>();

// Separate completed from pending
foreach (var (commandId, operation) in operationsList)
{
if (existingResults.TryGetValue(commandId, out var existingResult) && existingResult != null)
if (existingStatuses.TryGetValue(commandId, out var status) && status == CommandExecutionStatus.Completed)
{
results[commandId] = existingResult;
existingResults.TryGetValue(commandId, out var existingResult);
results[commandId] = existingResult ?? default!;
}
else
{
Expand Down Expand Up @@ -255,7 +249,7 @@ private static async Task<T> WaitForCompletionAsync<T>(
{
case CommandExecutionStatus.Completed:
var result = await store.GetCommandResultAsync<T>(commandId, cancellationToken);
return result ?? throw new InvalidOperationException($"Command {commandId} completed but result not found");
return result ?? default!;

case CommandExecutionStatus.Failed:
throw new InvalidOperationException($"Command {commandId} failed during execution");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,27 @@ public Task<CommandExecutionStatus> GetStatusAsync()

public async Task<bool> TryStartProcessingAsync()
{
// Check if already processing or completed
if (state.State.Status != CommandExecutionStatus.NotFound)
// Reject concurrent executions
switch (state.State.Status)
{
return false;
case CommandExecutionStatus.InProgress:
case CommandExecutionStatus.Processing:
case CommandExecutionStatus.Completed:
return false;

case CommandExecutionStatus.Failed:
state.State.Result = null;
state.State.ErrorMessage = null;
state.State.CompletedAt = null;
state.State.FailedAt = null;
break;

case CommandExecutionStatus.NotFound:
case CommandExecutionStatus.NotStarted:
break;

default:
return false;
}

state.State.Status = CommandExecutionStatus.Processing;
Expand All @@ -40,6 +57,47 @@ public async Task<bool> TryStartProcessingAsync()
return true;
}

public async Task<bool> TrySetStatusAsync(CommandExecutionStatus expectedStatus, CommandExecutionStatus newStatus)
{
if (state.State.Status != expectedStatus)
{
return false;
}

switch (newStatus)
{
case CommandExecutionStatus.InProgress:
case CommandExecutionStatus.Processing:
return await TryStartProcessingAsync();

case CommandExecutionStatus.Completed:
await MarkCompletedAsync(state.State.Result);
return true;

case CommandExecutionStatus.Failed:
await MarkFailedAsync(state.State.ErrorMessage ?? "Status set to failed");
return true;

case CommandExecutionStatus.NotFound:
await ClearAsync();
return true;

case CommandExecutionStatus.NotStarted:
state.State.Status = CommandExecutionStatus.NotStarted;
state.State.Result = null;
state.State.ErrorMessage = null;
state.State.StartedAt = null;
state.State.CompletedAt = null;
state.State.FailedAt = null;
state.State.ExpiresAt = null;
await state.WriteStateAsync();
return true;

default:
return false;
}
}

public async Task MarkCompletedAsync<TResult>(TResult result)
{
state.State.Status = CommandExecutionStatus.Completed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,12 @@ public interface ICommandIdempotencyGrain : IGrainWithStringKey
/// Clears the command state from the grain.
/// </summary>
Task ClearAsync();
}

/// <summary>
/// Attempts to transition the command to a new status when the current status matches the expected value.
/// </summary>
/// <param name="expectedStatus">The status the caller believes the command currently has.</param>
/// <param name="newStatus">The desired status to transition to.</param>
/// <returns><c>true</c> when the transition succeeds, otherwise <c>false</c>.</returns>
Task<bool> TrySetStatusAsync(CommandExecutionStatus expectedStatus, CommandExecutionStatus newStatus);
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ public async Task SetCommandStatusAsync(string commandId, CommandExecutionStatus
await grain.MarkFailedAsync("Status set to failed");
break;
case CommandExecutionStatus.Completed:
await grain.MarkCompletedAsync<object?>(null);
var (hasResult, result) = await grain.TryGetResultAsync();
if (hasResult)
{
await grain.MarkCompletedAsync(result);
}
else
{
await grain.MarkCompletedAsync<object?>(default);
}
break;
case CommandExecutionStatus.NotStarted:
case CommandExecutionStatus.NotFound:
Expand Down Expand Up @@ -80,14 +88,12 @@ public async Task RemoveCommandAsync(string commandId, CancellationToken cancell
public async Task<bool> TrySetCommandStatusAsync(string commandId, CommandExecutionStatus expectedStatus, CommandExecutionStatus newStatus, CancellationToken cancellationToken = default)
{
var grain = _grainFactory.GetGrain<ICommandIdempotencyGrain>(commandId);
var currentStatus = await grain.GetStatusAsync();

if (currentStatus == expectedStatus)

if (await grain.TrySetStatusAsync(expectedStatus, newStatus))
{
await SetCommandStatusAsync(commandId, newStatus, cancellationToken);
return true;
}

return false;
}

Expand Down Expand Up @@ -175,7 +181,21 @@ public async Task MarkFailedAsync(Guid commandId, string errorMessage, Cancellat

public async Task<(bool success, TResult? result)> TryGetResultAsync<TResult>(Guid commandId, CancellationToken cancellationToken = default)
{
var result = await GetCommandResultAsync<TResult>(commandId.ToString(), cancellationToken);
return (result != null, result);
var grain = _grainFactory.GetGrain<ICommandIdempotencyGrain>(commandId.ToString());
var status = await grain.GetStatusAsync();

if (status != CommandExecutionStatus.Completed)
{
return (false, default);
}

var (_, result) = await grain.TryGetResultAsync();

if (result is TResult typedResult)
{
return (true, typedResult);
}

return (true, default);
}
}
Loading