diff --git a/pkg/codec/decode.go b/pkg/codec/decode.go
index 251f92c4..81c84e7f 100644
--- a/pkg/codec/decode.go
+++ b/pkg/codec/decode.go
@@ -13,6 +13,7 @@ import (
raptorq "github.com/LumeraProtocol/rq-go"
"github.com/LumeraProtocol/supernode/v2/pkg/errors"
"github.com/LumeraProtocol/supernode/v2/pkg/logtrace"
+ "github.com/google/uuid"
)
type DecodeRequest struct {
@@ -29,7 +30,7 @@ type DecodeResponse struct {
// Workspace holds paths & reverse index for prepared decoding.
type Workspace struct {
ActionID string
- SymbolsDir string // ...//
+ SymbolsDir string // <base>/downloads/<actionID>/<uuid>
BlockDirs []string // index = blockID (or 0 if single block)
symbolToBlock map[string]int
mu sync.RWMutex // protects symbolToBlock reads if you expand it later
@@ -51,8 +52,12 @@ func (rq *raptorQ) PrepareDecode(
}
logtrace.Info(ctx, "rq: prepare-decode start", fields)
- // Create root symbols dir for this action
- symbolsDir := filepath.Join(rq.symbolsBaseDir, actionID)
+ // Create per-request workspace under <base>/downloads/<actionID>/<uuid>
+ base := rq.symbolsBaseDir
+ if base == "" {
+ base = os.TempDir()
+ }
+ symbolsDir := filepath.Join(base, "downloads", actionID, uuid.NewString())
if err := os.MkdirAll(symbolsDir, 0o755); err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "mkdir symbols base dir failed", fields)
diff --git a/pkg/codec/decode_workspace_test.go b/pkg/codec/decode_workspace_test.go
new file mode 100644
index 00000000..4a8a664c
--- /dev/null
+++ b/pkg/codec/decode_workspace_test.go
@@ -0,0 +1,51 @@
+package codec
+
+import (
+ "context"
+ "os"
+ "testing"
+)
+
+func TestPrepareDecode_UniqueWorkspacePerCall(t *testing.T) {
+ base := t.TempDir()
+ c := NewRaptorQCodec(base)
+ layout := Layout{Blocks: []Block{{BlockID: 0, Symbols: []string{"s1"}}}}
+
+ _, _, cleanup1, ws1, err := c.PrepareDecode(context.Background(), "actionA", layout)
+ if err != nil {
+ t.Fatalf("prepare decode 1: %v", err)
+ }
+ t.Cleanup(func() { _ = cleanup1() })
+ if ws1 == nil || ws1.SymbolsDir == "" {
+ t.Fatalf("prepare decode 1 returned empty workspace")
+ }
+ if _, err := os.Stat(ws1.SymbolsDir); err != nil {
+ t.Fatalf("stat ws1: %v", err)
+ }
+
+ _, _, cleanup2, ws2, err := c.PrepareDecode(context.Background(), "actionA", layout)
+ if err != nil {
+ t.Fatalf("prepare decode 2: %v", err)
+ }
+ t.Cleanup(func() { _ = cleanup2() })
+ if ws2 == nil || ws2.SymbolsDir == "" {
+ t.Fatalf("prepare decode 2 returned empty workspace")
+ }
+ if _, err := os.Stat(ws2.SymbolsDir); err != nil {
+ t.Fatalf("stat ws2: %v", err)
+ }
+
+ if ws1.SymbolsDir == ws2.SymbolsDir {
+ t.Fatalf("expected unique workspace per call; got same dir: %s", ws1.SymbolsDir)
+ }
+
+ if err := cleanup1(); err != nil {
+ t.Fatalf("cleanup 1: %v", err)
+ }
+ if _, err := os.Stat(ws1.SymbolsDir); !os.IsNotExist(err) {
+ t.Fatalf("expected ws1 removed; stat err=%v", err)
+ }
+ if _, err := os.Stat(ws2.SymbolsDir); err != nil {
+ t.Fatalf("expected ws2 still present after ws1 cleanup; stat err=%v", err)
+ }
+}
diff --git a/pkg/task/handle.go b/pkg/task/handle.go
index 74f6e406..4cc03090 100644
--- a/pkg/task/handle.go
+++ b/pkg/task/handle.go
@@ -2,12 +2,15 @@ package task
import (
"context"
+ "errors"
"sync"
"time"
"github.com/LumeraProtocol/supernode/v2/pkg/logtrace"
)
+var ErrAlreadyRunning = errors.New("task already running")
+
// Handle manages a running task with an optional watchdog.
// It ensures Start and End are paired, logs start/end, and auto-ends on timeout.
type Handle struct {
@@ -41,6 +44,37 @@ func StartWith(tr Tracker, ctx context.Context, service, id string, timeout time
return g
}
+// StartUniqueWith starts tracking a task only if it's not already tracked for the same
+// (service, id) pair. It returns ErrAlreadyRunning if the task is already in-flight.
+func StartUniqueWith(tr Tracker, ctx context.Context, service, id string, timeout time.Duration) (*Handle, error) {
+ if tr == nil || service == "" || id == "" {
+ return &Handle{}, nil
+ }
+
+ if ts, ok := tr.(interface {
+ TryStart(service, taskID string) bool
+ }); ok {
+ if !ts.TryStart(service, id) {
+ return nil, ErrAlreadyRunning
+ }
+ } else { // fallback: can't enforce uniqueness with unknown Tracker implementations
+ tr.Start(service, id)
+ }
+
+ logtrace.Info(ctx, "task: started", logtrace.Fields{"service": service, "task_id": id})
+ g := &Handle{tr: tr, service: service, id: id, stop: make(chan struct{})}
+ if timeout > 0 {
+ go func() {
+ select {
+ case <-time.After(timeout):
+ g.endWith(ctx, true)
+ case <-g.stop:
+ }
+ }()
+ }
+ return g, nil
+}
+
// End stops tracking the task. Safe to call multiple times.
func (g *Handle) End(ctx context.Context) {
g.endWith(ctx, false)
diff --git a/pkg/task/task.go b/pkg/task/task.go
index 8d0c0052..6b2f98a3 100644
--- a/pkg/task/task.go
+++ b/pkg/task/task.go
@@ -30,6 +30,28 @@ func New() *InMemoryTracker {
return &InMemoryTracker{data: make(map[string]map[string]struct{})}
}
+// TryStart attempts to mark a task as running under a given service.
+// It returns true if the task was newly started, or false if it was already running
+// (or if inputs are invalid). This is useful for "only one in-flight task" guards.
+func (t *InMemoryTracker) TryStart(service, taskID string) bool {
+ if service == "" || taskID == "" {
+ return false
+ }
+ t.mu.Lock()
+ m, ok := t.data[service]
+ if !ok {
+ m = make(map[string]struct{})
+ t.data[service] = m
+ }
+ if _, exists := m[taskID]; exists {
+ t.mu.Unlock()
+ return false
+ }
+ m[taskID] = struct{}{}
+ t.mu.Unlock()
+ return true
+}
+
// Start marks a task as running under a given service. Empty arguments
// are ignored. Calling Start with the same (service, taskID) pair is idempotent.
func (t *InMemoryTracker) Start(service, taskID string) {
diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go
index 1550bc37..0e1c660a 100644
--- a/pkg/task/task_test.go
+++ b/pkg/task/task_test.go
@@ -2,6 +2,7 @@ package task
import (
"context"
+ "errors"
"sync"
"testing"
"time"
@@ -155,3 +156,27 @@ func TestHandleIdempotentAndWatchdog(t *testing.T) {
}
}
}
+
+func TestStartUniqueWith_PreventsDuplicates(t *testing.T) {
+ tr := New()
+ ctx := context.Background()
+
+ h1, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0)
+ if err != nil {
+ t.Fatalf("StartUniqueWith 1: %v", err)
+ }
+ t.Cleanup(func() { h1.End(ctx) })
+
+ h2, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0)
+ if !errors.Is(err, ErrAlreadyRunning) {
+ t.Fatalf("expected ErrAlreadyRunning, got handle=%v err=%v", h2, err)
+ }
+
+ // After ending, it should be startable again.
+ h1.End(ctx)
+ h3, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0)
+ if err != nil {
+ t.Fatalf("StartUniqueWith 2: %v", err)
+ }
+ h3.End(ctx)
+}
diff --git a/supernode/adaptors/p2p.go b/supernode/adaptors/p2p.go
index 31184fd7..ce218a4f 100644
--- a/supernode/adaptors/p2p.go
+++ b/supernode/adaptors/p2p.go
@@ -134,7 +134,7 @@ func (p *p2pImpl) storeCascadeSymbolsAndData(ctx context.Context, taskID, action
}
start = end
}
- if err := p.rqStore.UpdateIsFirstBatchStored(actionID); err != nil {
+ if err := p.rqStore.UpdateIsFirstBatchStored(taskID); err != nil {
return totalSymbols, totalAvailable, fmt.Errorf("update first-batch flag: %w", err)
}
return totalSymbols, totalAvailable, nil
diff --git a/supernode/cascade/download.go b/supernode/cascade/download.go
index aa61a82a..02ab879d 100644
--- a/supernode/cascade/download.go
+++ b/supernode/cascade/download.go
@@ -48,7 +48,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download
return task.wrapErr(ctx, "failed to get action", err, fields)
}
logtrace.Info(ctx, "download: action fetched", fields)
- task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send); err != nil {
+ return err
+ }
if actionDetails.GetAction().State != actiontypes.ActionStateDone {
err = errors.New("action is not in a valid state")
@@ -64,7 +66,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download
return task.wrapErr(ctx, "error decoding cascade metadata", err, fields)
}
logtrace.Info(ctx, "download: metadata decoded", fields)
- task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send); err != nil {
+ return err
+ }
if !metadata.Public {
if req.Signature == "" {
@@ -80,7 +84,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download
logtrace.Info(ctx, "download: public cascade (no signature)", fields)
}
- task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send); err != nil {
+ return err
+ }
logtrace.Info(ctx, "download: network retrieval start", logtrace.Fields{logtrace.FieldActionID: actionDetails.GetAction().ActionID})
filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields, send)
@@ -91,10 +97,20 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download
logtrace.Warn(ctx, "cleanup of tmp dir after error failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()})
}
}
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
}
logtrace.Debug(ctx, "File reconstructed and hash verified", fields)
- task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send); err != nil {
+ if tmpDir != "" {
+ if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil {
+ logtrace.Warn(ctx, "cleanup of tmp dir after stream failure failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()})
+ }
+ }
+ return err
+ }
return nil
}
@@ -127,8 +143,11 @@ func (task *CascadeRegistrationTask) VerifyDownloadSignature(ctx context.Context
return nil
}
-func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg, filePath, dir string, send func(resp *DownloadResponse) error) {
- _ = send(&DownloadResponse{EventType: eventType, Message: msg, FilePath: filePath, DownloadedDir: dir})
+func (task *CascadeRegistrationTask) streamDownloadEvent(ctx context.Context, eventType SupernodeEventType, msg, filePath, dir string, send func(resp *DownloadResponse) error) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ return send(&DownloadResponse{EventType: eventType, Message: msg, FilePath: filePath, DownloadedDir: dir})
}
func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields, send func(resp *DownloadResponse) error) (string, string, error) {
@@ -244,7 +263,9 @@ func (task *CascadeRegistrationTask) restoreFileFromLayoutDeprecated(ctx context
// Emit minimal JSON payload (metrics system removed)
info := map[string]interface{}{"action_id": actionID, "found_symbols": len(symbols), "target_percent": targetRequiredPercent}
if b, err := json.Marshal(info); err == nil {
- task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send); err != nil {
+ return "", decodeInfo.DecodeTmpDir, err
+ }
}
return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil
}
@@ -399,7 +420,9 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
// Event
info := map[string]interface{}{"action_id": actionID, "found_symbols": int(atomic.LoadInt32(&written)), "target_percent": targetRequiredPercent}
if b, err := json.Marshal(info); err == nil {
- task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send)
+ if err := task.streamDownloadEvent(ctx, SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send); err != nil {
+ return "", decodeInfo.DecodeTmpDir, err
+ }
}
success = true
diff --git a/supernode/cascade/helper.go b/supernode/cascade/helper.go
index 394b59f1..b5537d24 100644
--- a/supernode/cascade/helper.go
+++ b/supernode/cascade/helper.go
@@ -165,13 +165,13 @@ func (task *CascadeRegistrationTask) wrapErr(ctx context.Context, msg string, er
return status.Errorf(codes.Internal, "%s", msg)
}
-func (task *CascadeRegistrationTask) emitArtefactsStored(ctx context.Context, fields logtrace.Fields, _ codec.Layout, send func(resp *RegisterResponse) error) {
+func (task *CascadeRegistrationTask) emitArtefactsStored(ctx context.Context, fields logtrace.Fields, _ codec.Layout, send func(resp *RegisterResponse) error) error {
if fields == nil {
fields = logtrace.Fields{}
}
msg := "Artefacts stored"
logtrace.Info(ctx, "register: artefacts stored", fields)
- task.streamEvent(SupernodeEventTypeArtefactsStored, msg, "", send)
+ return task.streamEvent(ctx, SupernodeEventTypeArtefactsStored, msg, "", send)
}
func (task *CascadeRegistrationTask) verifyActionFee(ctx context.Context, action *actiontypes.Action, dataSize int, fields logtrace.Fields) error {
diff --git a/supernode/cascade/register.go b/supernode/cascade/register.go
index 926f9b31..1693d61f 100644
--- a/supernode/cascade/register.go
+++ b/supernode/cascade/register.go
@@ -59,14 +59,18 @@ func (task *CascadeRegistrationTask) Register(
fields[logtrace.FieldStatus] = action.State
fields[logtrace.FieldPrice] = action.Price
logtrace.Info(ctx, "register: action fetched", fields)
- task.streamEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeActionRetrieved, "Action retrieved", "", send); err != nil {
+ return err
+ }
// Step 4: Verify action fee based on data size (rounded up to KB)
if err := task.verifyActionFee(ctx, action, req.DataSize, fields); err != nil {
return err
}
logtrace.Info(ctx, "register: fee verified", fields)
- task.streamEvent(SupernodeEventTypeActionFeeVerified, "Action fee verified", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeActionFeeVerified, "Action fee verified", "", send); err != nil {
+ return err
+ }
// Step 5: Ensure this node is eligible (top supernode for block)
fields[logtrace.FieldSupernodeState] = task.SupernodeAccountAddress
@@ -74,7 +78,9 @@ func (task *CascadeRegistrationTask) Register(
return err
}
logtrace.Info(ctx, "register: top supernode confirmed", fields)
- task.streamEvent(SupernodeEventTypeTopSupernodeCheckPassed, "Top supernode eligibility confirmed", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeTopSupernodeCheckPassed, "Top supernode eligibility confirmed", "", send); err != nil {
+ return err
+ }
// Step 6: Decode Cascade metadata from the action
cascadeMeta, err := cascadekit.UnmarshalCascadeMetadata(action.Metadata)
@@ -82,7 +88,9 @@ func (task *CascadeRegistrationTask) Register(
return task.wrapErr(ctx, "failed to unmarshal cascade metadata", err, fields)
}
logtrace.Info(ctx, "register: metadata decoded", fields)
- task.streamEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", send); err != nil {
+ return err
+ }
// Step 7: Verify request-provided data hash matches metadata
if err := cascadekit.VerifyB64DataHash(req.DataHash, cascadeMeta.DataHash); err != nil {
@@ -90,7 +98,9 @@ func (task *CascadeRegistrationTask) Register(
}
logtrace.Debug(ctx, "request data-hash has been matched with the action data-hash", fields)
logtrace.Info(ctx, "register: data hash matched", fields)
- task.streamEvent(SupernodeEventTypeDataHashVerified, "Data hash verified", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeDataHashVerified, "Data hash verified", "", send); err != nil {
+ return err
+ }
// Step 8: Encode input using the RQ codec to produce layout and symbols
encodeResult, err := task.encodeInput(ctx, req.ActionID, req.FilePath, fields)
@@ -99,7 +109,9 @@ func (task *CascadeRegistrationTask) Register(
}
fields["symbols_dir"] = encodeResult.SymbolsDir
logtrace.Info(ctx, "register: input encoded", fields)
- task.streamEvent(SupernodeEventTypeInputEncoded, "Input encoded", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeInputEncoded, "Input encoded", "", send); err != nil {
+ return err
+ }
// Step 9: Verify index and layout signatures; produce layoutB64
logtrace.Info(ctx, "register: verify+decode layout start", fields)
@@ -109,7 +121,9 @@ func (task *CascadeRegistrationTask) Register(
}
layoutSignatureB64 := indexFile.LayoutSignature
logtrace.Info(ctx, "register: signature verified", fields)
- task.streamEvent(SupernodeEventTypeSignatureVerified, "Signature verified", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeSignatureVerified, "Signature verified", "", send); err != nil {
+ return err
+ }
// Step 10: Generate RQID files (layout and index) and compute IDs
rqIDs, idFiles, err := task.generateRQIDFiles(ctx, cascadeMeta, layoutSignatureB64, layoutB64, fields)
@@ -129,26 +143,36 @@ func (task *CascadeRegistrationTask) Register(
fields["combined_files_size_kb"] = float64(totalSize) / 1024
fields["combined_files_size_mb"] = float64(totalSize) / (1024 * 1024)
logtrace.Info(ctx, "register: rqid files generated", fields)
- task.streamEvent(SupernodeEventTypeRQIDsGenerated, "RQID files generated", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeRQIDsGenerated, "RQID files generated", "", send); err != nil {
+ return err
+ }
logtrace.Info(ctx, "register: rqids validated", fields)
- task.streamEvent(SupernodeEventTypeRqIDsVerified, "RQIDs verified", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeRqIDsVerified, "RQIDs verified", "", send); err != nil {
+ return err
+ }
// Step 11: Simulate finalize to ensure the tx will succeed
if _, err := task.LumeraClient.SimulateFinalizeAction(ctx, action.ActionID, rqIDs); err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Info(ctx, "register: finalize simulation failed", fields)
- task.streamEvent(SupernodeEventTypeFinalizeSimulationFailed, "Finalize simulation failed", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeFinalizeSimulationFailed, "Finalize simulation failed", "", send); err != nil {
+ return err
+ }
return task.wrapErr(ctx, "finalize action simulation failed", err, fields)
}
logtrace.Info(ctx, "register: finalize simulation passed", fields)
- task.streamEvent(SupernodeEventTypeFinalizeSimulated, "Finalize simulation passed", "", send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeFinalizeSimulated, "Finalize simulation passed", "", send); err != nil {
+ return err
+ }
// Step 12: Store artefacts to the network store
if err := task.storeArtefacts(ctx, action.ActionID, idFiles, encodeResult.SymbolsDir, fields); err != nil {
return err
}
- task.emitArtefactsStored(ctx, fields, encodeResult.Layout, send)
+ if err := task.emitArtefactsStored(ctx, fields, encodeResult.Layout, send); err != nil {
+ return err
+ }
// Step 13: Finalize the action on-chain
resp, err := task.LumeraClient.FinalizeAction(ctx, action.ActionID, rqIDs)
@@ -160,6 +184,8 @@ func (task *CascadeRegistrationTask) Register(
txHash := resp.TxResponse.TxHash
fields[logtrace.FieldTxHash] = txHash
logtrace.Info(ctx, "register: action finalized", fields)
- task.streamEvent(SupernodeEventTypeActionFinalized, "Action finalized", txHash, send)
+ if err := task.streamEvent(ctx, SupernodeEventTypeActionFinalized, "Action finalized", txHash, send); err != nil {
+ return err
+ }
return nil
}
diff --git a/supernode/cascade/stream_send_error_test.go b/supernode/cascade/stream_send_error_test.go
new file mode 100644
index 00000000..e1b32e7a
--- /dev/null
+++ b/supernode/cascade/stream_send_error_test.go
@@ -0,0 +1,83 @@
+package cascade
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ actiontypes "github.com/LumeraProtocol/lumera/x/action/v1/types"
+ sntypes "github.com/LumeraProtocol/lumera/x/supernode/v1/types"
+ sdktx "github.com/cosmos/cosmos-sdk/types/tx"
+)
+
+type stubLumeraClient struct {
+ action *actiontypes.Action
+}
+
+func (s *stubLumeraClient) GetAction(_ context.Context, _ string) (*actiontypes.QueryGetActionResponse, error) {
+ return &actiontypes.QueryGetActionResponse{Action: s.action}, nil
+}
+
+func (s *stubLumeraClient) GetTopSupernodes(context.Context, uint64) (*sntypes.QueryGetTopSuperNodesForBlockResponse, error) {
+ panic("unexpected call")
+}
+
+func (s *stubLumeraClient) Verify(context.Context, string, []byte, []byte) error {
+ panic("unexpected call")
+}
+
+func (s *stubLumeraClient) GetActionFee(context.Context, string) (*actiontypes.QueryGetActionFeeResponse, error) {
+ panic("unexpected call")
+}
+
+func (s *stubLumeraClient) SimulateFinalizeAction(context.Context, string, []string) (*sdktx.SimulateResponse, error) {
+ panic("unexpected call")
+}
+
+func (s *stubLumeraClient) FinalizeAction(context.Context, string, []string) (*sdktx.BroadcastTxResponse, error) {
+ panic("unexpected call")
+}
+
+func TestRegister_AbortsOnEventSendError(t *testing.T) {
+ sendErr := errors.New("send failed")
+ sendCalls := 0
+
+ service := &CascadeService{
+ LumeraClient: &stubLumeraClient{action: &actiontypes.Action{ActionID: "action123"}},
+ }
+ task := NewCascadeRegistrationTask(service)
+
+ err := task.Register(context.Background(), &RegisterRequest{TaskID: "task123", ActionID: "action123"}, func(*RegisterResponse) error {
+ sendCalls++
+ return sendErr
+ })
+
+ if !errors.Is(err, sendErr) {
+ t.Fatalf("expected send error; got %v", err)
+ }
+ if sendCalls != 1 {
+ t.Fatalf("expected 1 send call, got %d", sendCalls)
+ }
+}
+
+func TestDownload_AbortsOnEventSendError(t *testing.T) {
+ sendErr := errors.New("send failed")
+ sendCalls := 0
+
+ service := &CascadeService{
+ LumeraClient: &stubLumeraClient{action: &actiontypes.Action{ActionID: "action123", State: actiontypes.ActionStateDone}},
+ }
+ task := NewCascadeRegistrationTask(service)
+
+ err := task.Download(context.Background(), &DownloadRequest{ActionID: "action123"}, func(*DownloadResponse) error {
+ sendCalls++
+ return sendErr
+ })
+
+ if !errors.Is(err, sendErr) {
+ t.Fatalf("expected send error; got %v", err)
+ }
+ if sendCalls != 1 {
+ t.Fatalf("expected 1 send call, got %d", sendCalls)
+ }
+}
diff --git a/supernode/cascade/task.go b/supernode/cascade/task.go
index 71725d20..8539eee7 100644
--- a/supernode/cascade/task.go
+++ b/supernode/cascade/task.go
@@ -1,5 +1,7 @@
package cascade
+import "context"
+
// CascadeRegistrationTask is the task for cascade registration
type CascadeRegistrationTask struct {
*CascadeService
@@ -15,6 +17,10 @@ func NewCascadeRegistrationTask(service *CascadeService) *CascadeRegistrationTas
}
// streamEvent sends a RegisterResponse via the provided callback.
-func (task *CascadeRegistrationTask) streamEvent(eventType SupernodeEventType, msg, txHash string, send func(resp *RegisterResponse) error) {
- _ = send(&RegisterResponse{EventType: eventType, Message: msg, TxHash: txHash})
+// It propagates send failures so callers can abort work when the downstream is gone.
+func (task *CascadeRegistrationTask) streamEvent(ctx context.Context, eventType SupernodeEventType, msg, txHash string, send func(resp *RegisterResponse) error) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ return send(&RegisterResponse{EventType: eventType, Message: msg, TxHash: txHash})
}
diff --git a/supernode/transport/grpc/cascade/chunksize.go b/supernode/transport/grpc/cascade/chunksize.go
new file mode 100644
index 00000000..86cda1d8
--- /dev/null
+++ b/supernode/transport/grpc/cascade/chunksize.go
@@ -0,0 +1,34 @@
+package cascade
+
+// calculateOptimalChunkSize returns an optimal chunk size based on file size
+// to balance throughput and memory usage
+func calculateOptimalChunkSize(fileSize int64) int {
+ const (
+ minChunkSize = 64 * 1024 // 64 KB minimum
+ maxChunkSize = 4 * 1024 * 1024 // 4 MB maximum for 1GB+ files
+ smallFileThreshold = 1024 * 1024 // 1 MB
+ mediumFileThreshold = 50 * 1024 * 1024 // 50 MB
+ largeFileThreshold = 500 * 1024 * 1024 // 500 MB
+ )
+
+ var chunkSize int
+
+ switch {
+ case fileSize <= smallFileThreshold:
+ chunkSize = minChunkSize
+ case fileSize <= mediumFileThreshold:
+ chunkSize = 256 * 1024
+ case fileSize <= largeFileThreshold:
+ chunkSize = 1024 * 1024
+ default:
+ chunkSize = maxChunkSize
+ }
+
+ if chunkSize < minChunkSize {
+ chunkSize = minChunkSize
+ }
+ if chunkSize > maxChunkSize {
+ chunkSize = maxChunkSize
+ }
+ return chunkSize
+}
diff --git a/supernode/transport/grpc/cascade/download.go b/supernode/transport/grpc/cascade/download.go
new file mode 100644
index 00000000..7e25b23e
--- /dev/null
+++ b/supernode/transport/grpc/cascade/download.go
@@ -0,0 +1,108 @@
+package cascade
+
+import (
+ "fmt"
+ "io"
+ "os"
+
+ pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade"
+ "github.com/LumeraProtocol/supernode/v2/pkg/logtrace"
+ tasks "github.com/LumeraProtocol/supernode/v2/pkg/task"
+ cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade"
+)
+
+func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeService_DownloadServer) error {
+ ctx := stream.Context()
+ fields := logtrace.Fields{
+ logtrace.FieldMethod: "Download",
+ logtrace.FieldModule: "CascadeActionServer",
+ logtrace.FieldActionID: req.GetActionId(),
+ }
+ logtrace.Debug(ctx, "download request received", fields)
+
+ // Start live task tracking for the entire download RPC (including file streaming)
+ dlHandle := tasks.StartWith(server.tracker, ctx, serviceCascadeDownload, req.GetActionId(), server.downloadTimeout)
+ defer dlHandle.End(ctx)
+
+ // Prepare to capture decoded file path from task events
+ var decodedFilePath string
+ var tmpDir string
+
+ task := server.factory.NewCascadeRegistrationTask()
+ defer func() {
+ if tmpDir == "" {
+ return
+ }
+ if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil {
+ logtrace.Warn(ctx, "cleanup of tmp dir failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()})
+ }
+ }()
+ // Run cascade task Download; stream events back to client
+ err := task.Download(ctx, &cascadeService.DownloadRequest{ActionID: req.GetActionId(), Signature: req.GetSignature()}, func(resp *cascadeService.DownloadResponse) error {
+ // Forward event to gRPC client
+ evt := &pb.DownloadResponse{
+ ResponseType: &pb.DownloadResponse_Event{
+ Event: &pb.DownloadEvent{
+ EventType: pb.SupernodeEventType(resp.EventType),
+ Message: resp.Message,
+ },
+ },
+ }
+ if sendErr := stream.Send(evt); sendErr != nil {
+ return sendErr
+ }
+ // Capture decode-completed info for streaming
+ if resp.EventType == cascadeService.SupernodeEventTypeDecodeCompleted {
+ decodedFilePath = resp.FilePath
+ tmpDir = resp.DownloadedDir
+ }
+ return nil
+ })
+ if err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "download task failed", fields)
+ return fmt.Errorf("download task failed: %w", err)
+ }
+
+ if decodedFilePath == "" {
+ logtrace.Warn(ctx, "decode completed without file path", fields)
+ return nil
+ }
+
+ // Notify client that server is ready to stream the file
+ logtrace.Debug(ctx, "download: serve ready", logtrace.Fields{"event_type": cascadeService.SupernodeEventTypeServeReady, logtrace.FieldActionID: req.GetActionId()})
+ if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Event{Event: &pb.DownloadEvent{EventType: pb.SupernodeEventType_SERVE_READY, Message: "Serve ready"}}}); err != nil {
+ return fmt.Errorf("send serve-ready: %w", err)
+ }
+
+ // Stream file content in chunks
+ fi, err := os.Stat(decodedFilePath)
+ if err != nil {
+ return fmt.Errorf("stat decoded file: %w", err)
+ }
+ chunkSize := calculateOptimalChunkSize(fi.Size())
+ f, err := os.Open(decodedFilePath)
+ if err != nil {
+ return fmt.Errorf("open decoded file: %w", err)
+ }
+ defer f.Close()
+
+ buf := make([]byte, chunkSize)
+ for {
+ n, rerr := f.Read(buf)
+ if n > 0 {
+ if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Chunk{Chunk: &pb.DataChunk{Data: buf[:n]}}}); err != nil {
+ return fmt.Errorf("send chunk: %w", err)
+ }
+ }
+ if rerr == io.EOF {
+ break
+ }
+ if rerr != nil {
+ return fmt.Errorf("read decoded file: %w", rerr)
+ }
+ }
+
+ logtrace.Debug(ctx, "download stream completed", fields)
+ return nil
+}
diff --git a/supernode/transport/grpc/cascade/handler.go b/supernode/transport/grpc/cascade/handler.go
index 96237b98..f2924ffc 100644
--- a/supernode/transport/grpc/cascade/handler.go
+++ b/supernode/transport/grpc/cascade/handler.go
@@ -1,20 +1,11 @@
package cascade
import (
- "encoding/hex"
- "fmt"
- "hash"
- "io"
- "os"
- "path/filepath"
"time"
pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade"
- "github.com/LumeraProtocol/supernode/v2/pkg/errors"
- "github.com/LumeraProtocol/supernode/v2/pkg/logtrace"
tasks "github.com/LumeraProtocol/supernode/v2/pkg/task"
cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade"
- "lukechampine.com/blake3"
)
type ActionServer struct {
@@ -40,317 +31,3 @@ func NewCascadeActionServer(factory cascadeService.CascadeServiceFactory, tracke
}
return &ActionServer{factory: factory, tracker: tracker, uploadTimeout: uploadTO, downloadTimeout: downloadTO}
}
-
-// calculateOptimalChunkSize returns an optimal chunk size based on file size
-// to balance throughput and memory usage
-
-var (
- startedTask bool
- handle *tasks.Handle
-)
-
-func calculateOptimalChunkSize(fileSize int64) int {
- const (
- minChunkSize = 64 * 1024 // 64 KB minimum
- maxChunkSize = 4 * 1024 * 1024 // 4 MB maximum for 1GB+ files
- smallFileThreshold = 1024 * 1024 // 1 MB
- mediumFileThreshold = 50 * 1024 * 1024 // 50 MB
- largeFileThreshold = 500 * 1024 * 1024 // 500 MB
- )
-
- var chunkSize int
-
- switch {
- case fileSize <= smallFileThreshold:
- chunkSize = minChunkSize
- case fileSize <= mediumFileThreshold:
- chunkSize = 256 * 1024
- case fileSize <= largeFileThreshold:
- chunkSize = 1024 * 1024
- default:
- chunkSize = maxChunkSize
- }
-
- if chunkSize < minChunkSize {
- chunkSize = minChunkSize
- }
- if chunkSize > maxChunkSize {
- chunkSize = maxChunkSize
- }
- return chunkSize
-}
-
-func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) error {
- fields := logtrace.Fields{
- logtrace.FieldMethod: "Register",
- logtrace.FieldModule: "CascadeActionServer",
- }
-
- ctx := stream.Context()
- logtrace.Info(ctx, "register: stream open", fields)
-
- const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit
-
- var (
- metadata *pb.Metadata
- totalSize int
- )
-
- hasher, tempFile, tempFilePath, err := initializeHasherAndTempFile()
- if err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "failed to initialize hasher and temp file", fields)
- return fmt.Errorf("initializing hasher and temp file: %w", err)
- }
- defer func(tempFile *os.File) {
- err := tempFile.Close()
- if err != nil && !errors.Is(err, os.ErrClosed) {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Warn(ctx, "error closing temp file", fields)
- }
- }(tempFile)
-
- for {
- req, err := stream.Recv()
- if err == io.EOF {
- break
- }
- if err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "error receiving stream data", fields)
- return fmt.Errorf("failed to receive stream data: %w", err)
- }
-
- switch x := req.RequestType.(type) {
- case *pb.RegisterRequest_Chunk:
- if x.Chunk != nil {
- if _, err := hasher.Write(x.Chunk.Data); err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "failed to write chunk to hasher", fields)
- return fmt.Errorf("hashing error: %w", err)
- }
- if _, err := tempFile.Write(x.Chunk.Data); err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "failed to write chunk to file", fields)
- return fmt.Errorf("file write error: %w", err)
- }
- totalSize += len(x.Chunk.Data)
- if totalSize > maxFileSize {
- fields[logtrace.FieldError] = "file size exceeds 1GB limit"
- fields["total_size"] = totalSize
- logtrace.Error(ctx, "upload rejected: file too large", fields)
- return fmt.Errorf("file size %d exceeds maximum allowed size of 1GB", totalSize)
- }
- // Keep chunk logs at debug to avoid verbosity
- logtrace.Debug(ctx, "received data chunk", logtrace.Fields{"chunk_size": len(x.Chunk.Data), "total_size_so_far": totalSize})
- }
- case *pb.RegisterRequest_Metadata:
- metadata = x.Metadata
- // Set correlation ID for the rest of the flow
- ctx = logtrace.CtxWithCorrelationID(ctx, metadata.ActionId)
- fields[logtrace.FieldTaskID] = metadata.GetTaskId()
- fields[logtrace.FieldActionID] = metadata.GetActionId()
- logtrace.Info(ctx, "register: metadata received", fields)
- // Start live task tracking on first metadata (covers remaining stream and processing)
- if !startedTask {
- startedTask = true
- handle = tasks.StartWith(server.tracker, ctx, serviceCascadeUpload, metadata.ActionId, server.uploadTimeout)
- defer handle.End(ctx)
- }
- }
- }
-
- if metadata == nil {
- logtrace.Error(ctx, "no metadata received in stream", fields)
- return fmt.Errorf("no metadata received")
- }
- fields[logtrace.FieldTaskID] = metadata.GetTaskId()
- fields[logtrace.FieldActionID] = metadata.GetActionId()
- logtrace.Info(ctx, "register: stream upload complete", fields)
-
- if err := tempFile.Sync(); err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "failed to sync temp file", fields)
- return fmt.Errorf("failed to sync temp file: %w", err)
- }
-
- hash := hasher.Sum(nil)
- hashHex := hex.EncodeToString(hash)
- fields[logtrace.FieldHashHex] = hashHex
- logtrace.Info(ctx, "register: hash computed", fields)
-
- targetPath, err := replaceTempDirWithTaskDir(metadata.GetTaskId(), tempFilePath, tempFile)
- if err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "failed to replace temp dir with task dir", fields)
- return fmt.Errorf("failed to replace temp dir with task dir: %w", err)
- }
-
- task := server.factory.NewCascadeRegistrationTask()
- logtrace.Info(ctx, "register: task start", fields)
- err = task.Register(ctx, &cascadeService.RegisterRequest{
- TaskID: metadata.TaskId,
- ActionID: metadata.ActionId,
- DataHash: hash,
- DataSize: totalSize,
- FilePath: targetPath,
- }, func(resp *cascadeService.RegisterResponse) error {
- grpcResp := &pb.RegisterResponse{
- EventType: pb.SupernodeEventType(resp.EventType),
- Message: resp.Message,
- TxHash: resp.TxHash,
- }
- if err := stream.Send(grpcResp); err != nil {
- logtrace.Error(ctx, "failed to send response to client", logtrace.Fields{logtrace.FieldError: err.Error()})
- return err
- }
- // Mirror event to Info logs for high-level tracing
- logtrace.Info(ctx, "register: event", logtrace.Fields{"event_type": resp.EventType, "message": resp.Message, logtrace.FieldTxHash: resp.TxHash, logtrace.FieldActionID: metadata.ActionId, logtrace.FieldTaskID: metadata.TaskId})
- return nil
- })
- if err != nil {
- logtrace.Error(ctx, "registration task failed", logtrace.Fields{logtrace.FieldError: err.Error()})
- return fmt.Errorf("registration failed: %w", err)
- }
- logtrace.Info(ctx, "register: task ok", fields)
- return nil
-}
-
-func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeService_DownloadServer) error {
- ctx := stream.Context()
- fields := logtrace.Fields{
- logtrace.FieldMethod: "Download",
- logtrace.FieldModule: "CascadeActionServer",
- logtrace.FieldActionID: req.GetActionId(),
- }
- logtrace.Debug(ctx, "download request received", fields)
-
- // Start live task tracking for the entire download RPC (including file streaming)
- dlHandle := tasks.StartWith(server.tracker, ctx, serviceCascadeDownload, req.GetActionId(), server.downloadTimeout)
- defer dlHandle.End(ctx)
-
- // Prepare to capture decoded file path from task events
- var decodedFilePath string
- var tmpDir string
-
- task := server.factory.NewCascadeRegistrationTask()
- // Run cascade task Download; stream events back to client
- err := task.Download(ctx, &cascadeService.DownloadRequest{ActionID: req.GetActionId(), Signature: req.GetSignature()}, func(resp *cascadeService.DownloadResponse) error {
- // Forward event to gRPC client
- evt := &pb.DownloadResponse{
- ResponseType: &pb.DownloadResponse_Event{
- Event: &pb.DownloadEvent{
- EventType: pb.SupernodeEventType(resp.EventType),
- Message: resp.Message,
- },
- },
- }
- if sendErr := stream.Send(evt); sendErr != nil {
- return sendErr
- }
- // Capture decode-completed info for streaming
- if resp.EventType == cascadeService.SupernodeEventTypeDecodeCompleted {
- decodedFilePath = resp.FilePath
- tmpDir = resp.DownloadedDir
- }
- return nil
- })
- if err != nil {
- fields[logtrace.FieldError] = err.Error()
- logtrace.Error(ctx, "download task failed", fields)
- return fmt.Errorf("download task failed: %w", err)
- }
-
- if decodedFilePath == "" {
- logtrace.Warn(ctx, "decode completed without file path", fields)
- return nil
- }
-
- // Notify client that server is ready to stream the file
- logtrace.Debug(ctx, "download: serve ready", logtrace.Fields{"event_type": cascadeService.SupernodeEventTypeServeReady, logtrace.FieldActionID: req.GetActionId()})
- if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Event{Event: &pb.DownloadEvent{EventType: pb.SupernodeEventType_SERVE_READY, Message: "Serve ready"}}}); err != nil {
- return fmt.Errorf("send serve-ready: %w", err)
- }
-
- // Stream file content in chunks
- fi, err := os.Stat(decodedFilePath)
- if err != nil {
- return fmt.Errorf("stat decoded file: %w", err)
- }
- chunkSize := calculateOptimalChunkSize(fi.Size())
- f, err := os.Open(decodedFilePath)
- if err != nil {
- return fmt.Errorf("open decoded file: %w", err)
- }
- defer f.Close()
-
- buf := make([]byte, chunkSize)
- for {
- n, rerr := f.Read(buf)
- if n > 0 {
- if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Chunk{Chunk: &pb.DataChunk{Data: append([]byte(nil), buf[:n]...)}}}); err != nil {
- return fmt.Errorf("send chunk: %w", err)
- }
- }
- if rerr == io.EOF {
- break
- }
- if rerr != nil {
- return fmt.Errorf("read decoded file: %w", rerr)
- }
- }
-
- // Cleanup temp directory if provided
- if tmpDir != "" {
- if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil {
- logtrace.Warn(ctx, "cleanup of tmp dir failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()})
- }
- }
-
- logtrace.Debug(ctx, "download stream completed", fields)
- return nil
-}
-
-// initializeHasherAndTempFile prepares a hasher and a temporary file to stream upload data into.
-func initializeHasherAndTempFile() (hash.Hash, *os.File, string, error) {
- // Create a temp directory for the upload
- tmpDir, err := os.MkdirTemp("", "supernode-upload-*")
- if err != nil {
- return nil, nil, "", fmt.Errorf("create temp dir: %w", err)
- }
-
- // Create a file within the temp directory
- filePath := filepath.Join(tmpDir, "data.bin")
- f, err := os.Create(filePath)
- if err != nil {
- return nil, nil, "", fmt.Errorf("create temp file: %w", err)
- }
-
- // Create a BLAKE3 hasher (32 bytes output)
- hasher := blake3.New(32, nil)
- return hasher, f, filePath, nil
-}
-
-// replaceTempDirWithTaskDir moves the uploaded file into a task-scoped directory
-// and returns the new absolute path.
-func replaceTempDirWithTaskDir(taskID, tempFilePath string, tempFile *os.File) (string, error) {
- // Ensure data is flushed
- _ = tempFile.Sync()
- // Close now; deferred close may run later and is safe to ignore
- _ = tempFile.Close()
-
- // Create a stable target directory under OS temp
- targetDir := filepath.Join(os.TempDir(), "supernode", "uploads", taskID)
- if err := os.MkdirAll(targetDir, 0700); err != nil {
- return "", fmt.Errorf("create task dir: %w", err)
- }
-
- newPath := filepath.Join(targetDir, filepath.Base(tempFilePath))
- if err := os.Rename(tempFilePath, newPath); err != nil {
- return "", fmt.Errorf("move uploaded file: %w", err)
- }
-
- // Attempt to cleanup the original temp directory
- _ = os.RemoveAll(filepath.Dir(tempFilePath))
- return newPath, nil
-}
diff --git a/supernode/transport/grpc/cascade/handler_test.go b/supernode/transport/grpc/cascade/handler_test.go
new file mode 100644
index 00000000..e457f444
--- /dev/null
+++ b/supernode/transport/grpc/cascade/handler_test.go
@@ -0,0 +1,106 @@
+package cascade
+
+import (
+ "context"
+ "io"
+ "os"
+ "testing"
+ "time"
+
+ pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade"
+ tasks "github.com/LumeraProtocol/supernode/v2/pkg/task"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
+)
+
+type fakeRegisterStream struct {
+ ctx context.Context
+ reqs []*pb.RegisterRequest
+ i int
+}
+
+func (s *fakeRegisterStream) Send(*pb.RegisterResponse) error { return nil }
+
+func (s *fakeRegisterStream) Recv() (*pb.RegisterRequest, error) {
+ if s.i >= len(s.reqs) {
+ return nil, io.EOF
+ }
+ r := s.reqs[s.i]
+ s.i++
+ return r, nil
+}
+
+func (s *fakeRegisterStream) SetHeader(metadata.MD) error { return nil }
+func (s *fakeRegisterStream) SendHeader(metadata.MD) error { return nil }
+func (s *fakeRegisterStream) SetTrailer(metadata.MD) {}
+func (s *fakeRegisterStream) Context() context.Context {
+ if s.ctx != nil {
+ return s.ctx
+ }
+ return context.Background()
+}
+func (s *fakeRegisterStream) SendMsg(interface{}) error { return nil }
+func (s *fakeRegisterStream) RecvMsg(interface{}) error { return nil }
+
+func TestRegister_CleansTempDirOnHandlerError(t *testing.T) {
+ tmpRoot := t.TempDir()
+
+ prevTmpDir, hadPrevTmpDir := os.LookupEnv("TMPDIR")
+ t.Cleanup(func() {
+ if hadPrevTmpDir {
+ _ = os.Setenv("TMPDIR", prevTmpDir)
+ } else {
+ _ = os.Unsetenv("TMPDIR")
+ }
+ })
+ if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
+ t.Fatalf("set TMPDIR: %v", err)
+ }
+
+ server := &ActionServer{}
+ err := server.Register(&fakeRegisterStream{})
+ if err == nil {
+ t.Fatalf("expected error, got nil")
+ }
+
+ entries, rerr := os.ReadDir(tmpRoot)
+ if rerr != nil {
+ t.Fatalf("read tmpRoot: %v", rerr)
+ }
+ if len(entries) != 0 {
+ t.Fatalf("expected TMPDIR to be empty, found %d entries", len(entries))
+ }
+}
+
+func TestRegister_RejectsDuplicateActionID(t *testing.T) {
+ tmpRoot := t.TempDir()
+
+ prevTmpDir, hadPrevTmpDir := os.LookupEnv("TMPDIR")
+ t.Cleanup(func() {
+ if hadPrevTmpDir {
+ _ = os.Setenv("TMPDIR", prevTmpDir)
+ } else {
+ _ = os.Unsetenv("TMPDIR")
+ }
+ })
+ if err := os.Setenv("TMPDIR", tmpRoot); err != nil {
+ t.Fatalf("set TMPDIR: %v", err)
+ }
+
+ tr := tasks.New()
+ actionID := "action-1"
+ tr.Start(serviceCascadeUpload, actionID)
+
+ server := &ActionServer{tracker: tr, uploadTimeout: time.Second}
+ stream := &fakeRegisterStream{
+ reqs: []*pb.RegisterRequest{
+ {RequestType: &pb.RegisterRequest_Metadata{Metadata: &pb.Metadata{TaskId: "task-1", ActionId: actionID}}},
+ },
+ }
+
+ err := server.Register(stream)
+ if status.Code(err) != codes.AlreadyExists {
+ t.Fatalf("expected AlreadyExists, got %v", err)
+ }
+}
diff --git a/supernode/transport/grpc/cascade/register.go b/supernode/transport/grpc/cascade/register.go
new file mode 100644
index 00000000..6e6818cf
--- /dev/null
+++ b/supernode/transport/grpc/cascade/register.go
@@ -0,0 +1,222 @@
+package cascade
+
+import (
+ "encoding/hex"
+ "fmt"
+ "hash"
+ "io"
+ "os"
+ "path/filepath"
+
+ pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade"
+ "github.com/LumeraProtocol/supernode/v2/pkg/errors"
+ "github.com/LumeraProtocol/supernode/v2/pkg/logtrace"
+ tasks "github.com/LumeraProtocol/supernode/v2/pkg/task"
+ cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "lukechampine.com/blake3"
+)
+
+func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) error {
+ fields := logtrace.Fields{
+ logtrace.FieldMethod: "Register",
+ logtrace.FieldModule: "CascadeActionServer",
+ }
+
+ ctx := stream.Context()
+ logtrace.Info(ctx, "register: stream open", fields)
+
+ const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit
+
+ var (
+ metadata *pb.Metadata
+ totalSize int
+ uploadHandle *tasks.Handle
+ )
+
+ hasher, tempFile, tempDir, tempFilePath, err := initializeHasherAndTempFile()
+ if err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "failed to initialize hasher and temp file", fields)
+ return fmt.Errorf("initializing hasher and temp file: %w", err)
+ }
+ defer func() {
+ if tempDir == "" {
+ return
+ }
+ if err := os.RemoveAll(tempDir); err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ fields["temp_dir"] = tempDir
+ logtrace.Warn(ctx, "failed to cleanup upload temp dir", fields)
+ }
+ }()
+ defer func(tempFile *os.File) {
+ err := tempFile.Close()
+ if err != nil && !errors.Is(err, os.ErrClosed) {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Warn(ctx, "error closing temp file", fields)
+ }
+ }(tempFile)
+
+ for {
+ req, err := stream.Recv()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "error receiving stream data", fields)
+ return fmt.Errorf("failed to receive stream data: %w", err)
+ }
+
+ switch x := req.RequestType.(type) {
+ case *pb.RegisterRequest_Chunk:
+ if x.Chunk != nil {
+ if _, err := hasher.Write(x.Chunk.Data); err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "failed to write chunk to hasher", fields)
+ return fmt.Errorf("hashing error: %w", err)
+ }
+ if _, err := tempFile.Write(x.Chunk.Data); err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "failed to write chunk to file", fields)
+ return fmt.Errorf("file write error: %w", err)
+ }
+ totalSize += len(x.Chunk.Data)
+ if totalSize > maxFileSize {
+ fields[logtrace.FieldError] = "file size exceeds 1GB limit"
+ fields["total_size"] = totalSize
+ logtrace.Error(ctx, "upload rejected: file too large", fields)
+ return fmt.Errorf("file size %d exceeds maximum allowed size of 1GB", totalSize)
+ }
+ // Keep chunk logs at debug to avoid verbosity
+ logtrace.Debug(ctx, "received data chunk", logtrace.Fields{"chunk_size": len(x.Chunk.Data), "total_size_so_far": totalSize})
+ }
+ case *pb.RegisterRequest_Metadata:
+ metadata = x.Metadata
+ // Set correlation ID for the rest of the flow
+ ctx = logtrace.CtxWithCorrelationID(ctx, metadata.ActionId)
+ fields[logtrace.FieldTaskID] = metadata.GetTaskId()
+ fields[logtrace.FieldActionID] = metadata.GetActionId()
+ logtrace.Info(ctx, "register: metadata received", fields)
+ actionID := metadata.GetActionId()
+ if actionID == "" {
+ return status.Error(codes.InvalidArgument, "missing action_id")
+ }
+ // Start live task tracking on first metadata (covers remaining stream and processing).
+ // Track by ActionID to prevent duplicate in-flight uploads for the same action.
+ if uploadHandle == nil {
+ h, herr := tasks.StartUniqueWith(server.tracker, ctx, serviceCascadeUpload, actionID, server.uploadTimeout)
+ if herr != nil {
+ if errors.Is(herr, tasks.ErrAlreadyRunning) {
+ return status.Errorf(codes.AlreadyExists, "upload already in progress for %s", actionID)
+ }
+ return herr
+ }
+ uploadHandle = h
+ defer uploadHandle.End(ctx)
+ }
+ }
+ }
+
+ if metadata == nil {
+ logtrace.Error(ctx, "no metadata received in stream", fields)
+ return fmt.Errorf("no metadata received")
+ }
+ fields[logtrace.FieldTaskID] = metadata.GetTaskId()
+ fields[logtrace.FieldActionID] = metadata.GetActionId()
+ logtrace.Info(ctx, "register: stream upload complete", fields)
+
+ if err := tempFile.Sync(); err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "failed to sync temp file", fields)
+ return fmt.Errorf("failed to sync temp file: %w", err)
+ }
+
+ hash := hasher.Sum(nil)
+ hashHex := hex.EncodeToString(hash)
+ fields[logtrace.FieldHashHex] = hashHex
+ logtrace.Info(ctx, "register: hash computed", fields)
+
+ targetPath, err := replaceTempDirWithTaskDir(metadata.GetTaskId(), tempFilePath, tempFile)
+ if err != nil {
+ fields[logtrace.FieldError] = err.Error()
+ logtrace.Error(ctx, "failed to replace temp dir with task dir", fields)
+ return fmt.Errorf("failed to replace temp dir with task dir: %w", err)
+ }
+
+ task := server.factory.NewCascadeRegistrationTask()
+ logtrace.Info(ctx, "register: task start", fields)
+ err = task.Register(ctx, &cascadeService.RegisterRequest{
+ TaskID: metadata.TaskId,
+ ActionID: metadata.ActionId,
+ DataHash: hash,
+ DataSize: totalSize,
+ FilePath: targetPath,
+ }, func(resp *cascadeService.RegisterResponse) error {
+ grpcResp := &pb.RegisterResponse{
+ EventType: pb.SupernodeEventType(resp.EventType),
+ Message: resp.Message,
+ TxHash: resp.TxHash,
+ }
+ if err := stream.Send(grpcResp); err != nil {
+ logtrace.Error(ctx, "failed to send response to client", logtrace.Fields{logtrace.FieldError: err.Error()})
+ return err
+ }
+ // Mirror event to Info logs for high-level tracing
+ logtrace.Info(ctx, "register: event", logtrace.Fields{"event_type": resp.EventType, "message": resp.Message, logtrace.FieldTxHash: resp.TxHash, logtrace.FieldActionID: metadata.ActionId, logtrace.FieldTaskID: metadata.TaskId})
+ return nil
+ })
+ if err != nil {
+ logtrace.Error(ctx, "registration task failed", logtrace.Fields{logtrace.FieldError: err.Error()})
+ return fmt.Errorf("registration failed: %w", err)
+ }
+ logtrace.Info(ctx, "register: task ok", fields)
+ return nil
+}
+
+// initializeHasherAndTempFile prepares a BLAKE3 hasher and a temporary file to stream upload data into; it returns the hasher, the open file, the temp directory path (for cleanup), and the file path.
+func initializeHasherAndTempFile() (hash.Hash, *os.File, string, string, error) {
+ // Create a temp directory for the upload
+ tmpDir, err := os.MkdirTemp("", "supernode-upload-*")
+ if err != nil {
+ return nil, nil, "", "", fmt.Errorf("create temp dir: %w", err)
+ }
+
+ // Create a file within the temp directory
+ filePath := filepath.Join(tmpDir, "data.bin")
+ f, err := os.Create(filePath)
+ if err != nil {
+ _ = os.RemoveAll(tmpDir)
+ return nil, nil, "", "", fmt.Errorf("create temp file: %w", err)
+ }
+
+ // Create a BLAKE3 hasher (32 bytes output)
+ hasher := blake3.New(32, nil)
+ return hasher, f, tmpDir, filePath, nil
+}
+
+// replaceTempDirWithTaskDir moves the uploaded file into a task-scoped directory
+// and returns the new absolute path.
+func replaceTempDirWithTaskDir(taskID, tempFilePath string, tempFile *os.File) (string, error) {
+ // Ensure data is flushed
+ _ = tempFile.Sync()
+ // Close now; deferred close may run later and is safe to ignore
+ _ = tempFile.Close()
+
+ // Create a stable target directory under OS temp
+ targetDir := filepath.Join(os.TempDir(), "supernode", "uploads", taskID)
+ if err := os.MkdirAll(targetDir, 0700); err != nil {
+ return "", fmt.Errorf("create task dir: %w", err)
+ }
+
+ newPath := filepath.Join(targetDir, filepath.Base(tempFilePath))
+ if err := os.Rename(tempFilePath, newPath); err != nil {
+ return "", fmt.Errorf("move uploaded file: %w", err)
+ }
+
+	// Attempt to clean up the original temp directory
+ _ = os.RemoveAll(filepath.Dir(tempFilePath))
+ return newPath, nil
+}