diff --git a/pkg/codec/decode.go b/pkg/codec/decode.go index 251f92c4..81c84e7f 100644 --- a/pkg/codec/decode.go +++ b/pkg/codec/decode.go @@ -13,6 +13,7 @@ import ( raptorq "github.com/LumeraProtocol/rq-go" "github.com/LumeraProtocol/supernode/v2/pkg/errors" "github.com/LumeraProtocol/supernode/v2/pkg/logtrace" + "github.com/google/uuid" ) type DecodeRequest struct { @@ -29,7 +30,7 @@ type DecodeResponse struct { // Workspace holds paths & reverse index for prepared decoding. type Workspace struct { ActionID string - SymbolsDir string // ...// + SymbolsDir string // ...//downloads// BlockDirs []string // index = blockID (or 0 if single block) symbolToBlock map[string]int mu sync.RWMutex // protects symbolToBlock reads if you expand it later @@ -51,8 +52,12 @@ func (rq *raptorQ) PrepareDecode( } logtrace.Info(ctx, "rq: prepare-decode start", fields) - // Create root symbols dir for this action - symbolsDir := filepath.Join(rq.symbolsBaseDir, actionID) + // Create per-request workspace under /downloads// + base := rq.symbolsBaseDir + if base == "" { + base = os.TempDir() + } + symbolsDir := filepath.Join(base, "downloads", actionID, uuid.NewString()) if err := os.MkdirAll(symbolsDir, 0o755); err != nil { fields[logtrace.FieldError] = err.Error() logtrace.Error(ctx, "mkdir symbols base dir failed", fields) diff --git a/pkg/codec/decode_workspace_test.go b/pkg/codec/decode_workspace_test.go new file mode 100644 index 00000000..4a8a664c --- /dev/null +++ b/pkg/codec/decode_workspace_test.go @@ -0,0 +1,51 @@ +package codec + +import ( + "context" + "os" + "testing" +) + +func TestPrepareDecode_UniqueWorkspacePerCall(t *testing.T) { + base := t.TempDir() + c := NewRaptorQCodec(base) + layout := Layout{Blocks: []Block{{BlockID: 0, Symbols: []string{"s1"}}}} + + _, _, cleanup1, ws1, err := c.PrepareDecode(context.Background(), "actionA", layout) + if err != nil { + t.Fatalf("prepare decode 1: %v", err) + } + t.Cleanup(func() { _ = cleanup1() }) + if ws1 == nil || ws1.SymbolsDir == "" { + t.Fatalf("prepare decode 1 returned empty workspace") + } + if _, err := os.Stat(ws1.SymbolsDir); err != nil { + t.Fatalf("stat ws1: %v", err) + } + + _, _, cleanup2, ws2, err := c.PrepareDecode(context.Background(), "actionA", layout) + if err != nil { + t.Fatalf("prepare decode 2: %v", err) + } + t.Cleanup(func() { _ = cleanup2() }) + if ws2 == nil || ws2.SymbolsDir == "" { + t.Fatalf("prepare decode 2 returned empty workspace") + } + if _, err := os.Stat(ws2.SymbolsDir); err != nil { + t.Fatalf("stat ws2: %v", err) + } + + if ws1.SymbolsDir == ws2.SymbolsDir { + t.Fatalf("expected unique workspace per call; got same dir: %s", ws1.SymbolsDir) + } + + if err := cleanup1(); err != nil { + t.Fatalf("cleanup 1: %v", err) + } + if _, err := os.Stat(ws1.SymbolsDir); !os.IsNotExist(err) { + t.Fatalf("expected ws1 removed; stat err=%v", err) + } + if _, err := os.Stat(ws2.SymbolsDir); err != nil { + t.Fatalf("expected ws2 still present after ws1 cleanup; stat err=%v", err) + } +} diff --git a/pkg/task/handle.go b/pkg/task/handle.go index 74f6e406..4cc03090 100644 --- a/pkg/task/handle.go +++ b/pkg/task/handle.go @@ -2,12 +2,15 @@ package task import ( "context" + "errors" "sync" "time" "github.com/LumeraProtocol/supernode/v2/pkg/logtrace" ) +var ErrAlreadyRunning = errors.New("task already running") + // Handle manages a running task with an optional watchdog. // It ensures Start and End are paired, logs start/end, and auto-ends on timeout. type Handle struct { @@ -41,6 +44,37 @@ func StartWith(tr Tracker, ctx context.Context, service, id string, timeout time return g } +// StartUniqueWith starts tracking a task only if it's not already tracked for the same +// (service, id) pair. It returns ErrAlreadyRunning if the task is already in-flight. +func StartUniqueWith(tr Tracker, ctx context.Context, service, id string, timeout time.Duration) (*Handle, error) { + if tr == nil || service == "" || id == "" { + return &Handle{}, nil + } + + if ts, ok := tr.(interface { + TryStart(service, taskID string) bool + }); ok { + if !ts.TryStart(service, id) { + return nil, ErrAlreadyRunning + } + } else { // fallback: can't enforce uniqueness with unknown Tracker implementations + tr.Start(service, id) + } + + logtrace.Info(ctx, "task: started", logtrace.Fields{"service": service, "task_id": id}) + g := &Handle{tr: tr, service: service, id: id, stop: make(chan struct{})} + if timeout > 0 { + go func() { + select { + case <-time.After(timeout): + g.endWith(ctx, true) + case <-g.stop: + } + }() + } + return g, nil +} + // End stops tracking the task. Safe to call multiple times. func (g *Handle) End(ctx context.Context) { g.endWith(ctx, false) diff --git a/pkg/task/task.go b/pkg/task/task.go index 8d0c0052..6b2f98a3 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -30,6 +30,28 @@ func New() *InMemoryTracker { return &InMemoryTracker{data: make(map[string]map[string]struct{})} } +// TryStart attempts to mark a task as running under a given service. +// It returns true if the task was newly started, or false if it was already running +// (or if inputs are invalid). This is useful for "only one in-flight task" guards. +func (t *InMemoryTracker) TryStart(service, taskID string) bool { + if service == "" || taskID == "" { + return false + } + t.mu.Lock() + m, ok := t.data[service] + if !ok { + m = make(map[string]struct{}) + t.data[service] = m + } + if _, exists := m[taskID]; exists { + t.mu.Unlock() + return false + } + m[taskID] = struct{}{} + t.mu.Unlock() + return true +} + // Start marks a task as running under a given service. Empty arguments // are ignored. Calling Start with the same (service, taskID) pair is idempotent. func (t *InMemoryTracker) Start(service, taskID string) { diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go index 1550bc37..0e1c660a 100644 --- a/pkg/task/task_test.go +++ b/pkg/task/task_test.go @@ -2,6 +2,7 @@ package task import ( "context" + "errors" "sync" "testing" "time" @@ -155,3 +156,27 @@ func TestHandleIdempotentAndWatchdog(t *testing.T) { } } } + +func TestStartUniqueWith_PreventsDuplicates(t *testing.T) { + tr := New() + ctx := context.Background() + + h1, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0) + if err != nil { + t.Fatalf("StartUniqueWith 1: %v", err) + } + t.Cleanup(func() { h1.End(ctx) }) + + h2, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0) + if !errors.Is(err, ErrAlreadyRunning) { + t.Fatalf("expected ErrAlreadyRunning, got handle=%v err=%v", h2, err) + } + + // After ending, it should be startable again. + h1.End(ctx) + h3, err := StartUniqueWith(tr, ctx, "svc.unique", "id-1", 0) + if err != nil { + t.Fatalf("StartUniqueWith 2: %v", err) + } + h3.End(ctx) +} diff --git a/supernode/adaptors/p2p.go b/supernode/adaptors/p2p.go index 31184fd7..ce218a4f 100644 --- a/supernode/adaptors/p2p.go +++ b/supernode/adaptors/p2p.go @@ -134,7 +134,7 @@ func (p *p2pImpl) storeCascadeSymbolsAndData(ctx context.Context, taskID, action } start = end } - if err := p.rqStore.UpdateIsFirstBatchStored(actionID); err != nil { + if err := p.rqStore.UpdateIsFirstBatchStored(taskID); err != nil { return totalSymbols, totalAvailable, fmt.Errorf("update first-batch flag: %w", err) } return totalSymbols, totalAvailable, nil diff --git a/supernode/cascade/download.go b/supernode/cascade/download.go index aa61a82a..02ab879d 100644 --- a/supernode/cascade/download.go +++ b/supernode/cascade/download.go @@ -48,7 +48,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download return task.wrapErr(ctx, "failed to get action", err, fields) } logtrace.Info(ctx, "download: action fetched", fields) - task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send); err != nil { + return err + } if actionDetails.GetAction().State != actiontypes.ActionStateDone { err = errors.New("action is not in a valid state") @@ -64,7 +66,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download return task.wrapErr(ctx, "error decoding cascade metadata", err, fields) } logtrace.Info(ctx, "download: metadata decoded", fields) - task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send); err != nil { + return err + } if !metadata.Public { if req.Signature == "" { @@ -80,7 +84,9 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download logtrace.Info(ctx, "download: public cascade (no signature)", fields) } - task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send); err != nil { + return err + } logtrace.Info(ctx, "download: network retrieval start", logtrace.Fields{logtrace.FieldActionID: actionDetails.GetAction().ActionID}) filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields, send) @@ -91,10 +97,20 @@ func (task *CascadeRegistrationTask) Download(ctx context.Context, req *Download logtrace.Warn(ctx, "cleanup of tmp dir after error failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()}) } } + if ctx.Err() != nil { + return ctx.Err() + } return task.wrapErr(ctx, "failed to download artifacts", err, fields) } logtrace.Debug(ctx, "File reconstructed and hash verified", fields) - task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send); err != nil { + if tmpDir != "" { + if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil { + logtrace.Warn(ctx, "cleanup of tmp dir after stream failure failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()}) + } + } + return err + } return nil } @@ -127,8 +143,11 @@ func (task *CascadeRegistrationTask) VerifyDownloadSignature(ctx context.Context return nil } -func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg, filePath, dir string, send func(resp *DownloadResponse) error) { - _ = send(&DownloadResponse{EventType: eventType, Message: msg, FilePath: filePath, DownloadedDir: dir}) +func (task *CascadeRegistrationTask) streamDownloadEvent(ctx context.Context, eventType SupernodeEventType, msg, filePath, dir string, send func(resp *DownloadResponse) error) error { + if err := ctx.Err(); err != nil { + return err + } + return send(&DownloadResponse{EventType: eventType, Message: msg, FilePath: filePath, DownloadedDir: dir}) } func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields, send func(resp *DownloadResponse) error) (string, string, error) { @@ -244,7 +263,9 @@ func (task *CascadeRegistrationTask) restoreFileFromLayoutDeprecated(ctx context // Emit minimal JSON payload (metrics system removed) info := map[string]interface{}{"action_id": actionID, "found_symbols": len(symbols), "target_percent": targetRequiredPercent} if b, err := json.Marshal(info); err == nil { - task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send); err != nil { + return "", decodeInfo.DecodeTmpDir, err + } } return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil } @@ -399,7 +420,9 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout( // Event info := map[string]interface{}{"action_id": actionID, "found_symbols": int(atomic.LoadInt32(&written)), "target_percent": targetRequiredPercent} if b, err := json.Marshal(info); err == nil { - task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send) + if err := task.streamDownloadEvent(ctx, SupernodeEventTypeArtefactsDownloaded, string(b), decodeInfo.FilePath, decodeInfo.DecodeTmpDir, send); err != nil { + return "", decodeInfo.DecodeTmpDir, err + } } success = true diff --git a/supernode/cascade/helper.go b/supernode/cascade/helper.go index 394b59f1..b5537d24 100644 --- a/supernode/cascade/helper.go +++ b/supernode/cascade/helper.go @@ -165,13 +165,13 @@ func (task *CascadeRegistrationTask) wrapErr(ctx context.Context, msg string, er return status.Errorf(codes.Internal, "%s", msg) } -func (task *CascadeRegistrationTask) emitArtefactsStored(ctx context.Context, fields logtrace.Fields, _ codec.Layout, send func(resp *RegisterResponse) error) { +func (task *CascadeRegistrationTask) emitArtefactsStored(ctx context.Context, fields logtrace.Fields, _ codec.Layout, send func(resp *RegisterResponse) error) error { if fields == nil { fields = logtrace.Fields{} } msg := "Artefacts stored" logtrace.Info(ctx, "register: artefacts stored", fields) - task.streamEvent(SupernodeEventTypeArtefactsStored, msg, "", send) + return task.streamEvent(ctx, SupernodeEventTypeArtefactsStored, msg, "", send) } func (task *CascadeRegistrationTask) verifyActionFee(ctx context.Context, action *actiontypes.Action, dataSize int, fields logtrace.Fields) error { diff --git a/supernode/cascade/register.go b/supernode/cascade/register.go index 926f9b31..1693d61f 100644 --- a/supernode/cascade/register.go +++ b/supernode/cascade/register.go @@ -59,14 +59,18 @@ func (task *CascadeRegistrationTask) Register( fields[logtrace.FieldStatus] = action.State fields[logtrace.FieldPrice] = action.Price logtrace.Info(ctx, "register: action fetched", fields) - task.streamEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeActionRetrieved, "Action retrieved", "", send); err != nil { + return err + } // Step 4: Verify action fee based on data size (rounded up to KB) if err := task.verifyActionFee(ctx, action, req.DataSize, fields); err != nil { return err } logtrace.Info(ctx, "register: fee verified", fields) - task.streamEvent(SupernodeEventTypeActionFeeVerified, "Action fee verified", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeActionFeeVerified, "Action fee verified", "", send); err != nil { + return err + } // Step 5: Ensure this node is eligible (top supernode for block) fields[logtrace.FieldSupernodeState] = task.SupernodeAccountAddress @@ -74,7 +78,9 @@ func (task *CascadeRegistrationTask) Register( return err } logtrace.Info(ctx, "register: top supernode confirmed", fields) - task.streamEvent(SupernodeEventTypeTopSupernodeCheckPassed, "Top supernode eligibility confirmed", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeTopSupernodeCheckPassed, "Top supernode eligibility confirmed", "", send); err != nil { + return err + } // Step 6: Decode Cascade metadata from the action cascadeMeta, err := cascadekit.UnmarshalCascadeMetadata(action.Metadata) @@ -82,7 +88,9 @@ func (task *CascadeRegistrationTask) Register( return task.wrapErr(ctx, "failed to unmarshal cascade metadata", err, fields) } logtrace.Info(ctx, "register: metadata decoded", fields) - task.streamEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", send); err != nil { + return err + } // Step 7: Verify request-provided data hash matches metadata if err := cascadekit.VerifyB64DataHash(req.DataHash, cascadeMeta.DataHash); err != nil { @@ -90,7 +98,9 @@ func (task *CascadeRegistrationTask) Register( } logtrace.Debug(ctx, "request data-hash has been matched with the action data-hash", fields) logtrace.Info(ctx, "register: data hash matched", fields) - task.streamEvent(SupernodeEventTypeDataHashVerified, "Data hash verified", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeDataHashVerified, "Data hash verified", "", send); err != nil { + return err + } // Step 8: Encode input using the RQ codec to produce layout and symbols encodeResult, err := task.encodeInput(ctx, req.ActionID, req.FilePath, fields) @@ -99,7 +109,9 @@ func (task *CascadeRegistrationTask) Register( } fields["symbols_dir"] = encodeResult.SymbolsDir logtrace.Info(ctx, "register: input encoded", fields) - task.streamEvent(SupernodeEventTypeInputEncoded, "Input encoded", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeInputEncoded, "Input encoded", "", send); err != nil { + return err + } // Step 9: Verify index and layout signatures; produce layoutB64 logtrace.Info(ctx, "register: verify+decode layout start", fields) @@ -109,7 +121,9 @@ func (task *CascadeRegistrationTask) Register( } layoutSignatureB64 := indexFile.LayoutSignature logtrace.Info(ctx, "register: signature verified", fields) - task.streamEvent(SupernodeEventTypeSignatureVerified, "Signature verified", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeSignatureVerified, "Signature verified", "", send); err != nil { + return err + } // Step 10: Generate RQID files (layout and index) and compute IDs rqIDs, idFiles, err := task.generateRQIDFiles(ctx, cascadeMeta, layoutSignatureB64, layoutB64, fields) @@ -129,26 +143,36 @@ func (task *CascadeRegistrationTask) Register( fields["combined_files_size_kb"] = float64(totalSize) / 1024 fields["combined_files_size_mb"] = float64(totalSize) / (1024 * 1024) logtrace.Info(ctx, "register: rqid files generated", fields) - task.streamEvent(SupernodeEventTypeRQIDsGenerated, "RQID files generated", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeRQIDsGenerated, "RQID files generated", "", send); err != nil { + return err + } logtrace.Info(ctx, "register: rqids validated", fields) - task.streamEvent(SupernodeEventTypeRqIDsVerified, "RQIDs verified", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeRqIDsVerified, "RQIDs verified", "", send); err != nil { + return err + } // Step 11: Simulate finalize to ensure the tx will succeed if _, err := task.LumeraClient.SimulateFinalizeAction(ctx, action.ActionID, rqIDs); err != nil { fields[logtrace.FieldError] = err.Error() logtrace.Info(ctx, "register: finalize simulation failed", fields) - task.streamEvent(SupernodeEventTypeFinalizeSimulationFailed, "Finalize simulation failed", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeFinalizeSimulationFailed, "Finalize simulation failed", "", send); err != nil { + return err + } return task.wrapErr(ctx, "finalize action simulation failed", err, fields) } logtrace.Info(ctx, "register: finalize simulation passed", fields) - task.streamEvent(SupernodeEventTypeFinalizeSimulated, "Finalize simulation passed", "", send) + if err := task.streamEvent(ctx, SupernodeEventTypeFinalizeSimulated, "Finalize simulation passed", "", send); err != nil { + return err + } // Step 12: Store artefacts to the network store if err := task.storeArtefacts(ctx, action.ActionID, idFiles, encodeResult.SymbolsDir, fields); err != nil { return err } - task.emitArtefactsStored(ctx, fields, encodeResult.Layout, send) + if err := task.emitArtefactsStored(ctx, fields, encodeResult.Layout, send); err != nil { + return err + } // Step 13: Finalize the action on-chain resp, err := task.LumeraClient.FinalizeAction(ctx, action.ActionID, rqIDs) @@ -160,6 +184,8 @@ func (task *CascadeRegistrationTask) Register( txHash := resp.TxResponse.TxHash fields[logtrace.FieldTxHash] = txHash logtrace.Info(ctx, "register: action finalized", fields) - task.streamEvent(SupernodeEventTypeActionFinalized, "Action finalized", txHash, send) + if err := task.streamEvent(ctx, SupernodeEventTypeActionFinalized, "Action finalized", txHash, send); err != nil { + return err + } return nil } diff --git a/supernode/cascade/stream_send_error_test.go b/supernode/cascade/stream_send_error_test.go new file mode 100644 index 00000000..e1b32e7a --- /dev/null +++ b/supernode/cascade/stream_send_error_test.go @@ -0,0 +1,83 @@ +package cascade + +import ( + "context" + "errors" + "testing" + + actiontypes "github.com/LumeraProtocol/lumera/x/action/v1/types" + sntypes "github.com/LumeraProtocol/lumera/x/supernode/v1/types" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" +) + +type stubLumeraClient struct { + action *actiontypes.Action +} + +func (s *stubLumeraClient) GetAction(_ context.Context, _ string) (*actiontypes.QueryGetActionResponse, error) { + return &actiontypes.QueryGetActionResponse{Action: s.action}, nil +} + +func (s *stubLumeraClient) GetTopSupernodes(context.Context, uint64) (*sntypes.QueryGetTopSuperNodesForBlockResponse, error) { + panic("unexpected call") +} + +func (s *stubLumeraClient) Verify(context.Context, string, []byte, []byte) error { + panic("unexpected call") +} + +func (s *stubLumeraClient) GetActionFee(context.Context, string) (*actiontypes.QueryGetActionFeeResponse, error) { + panic("unexpected call") +} + +func (s *stubLumeraClient) SimulateFinalizeAction(context.Context, string, []string) (*sdktx.SimulateResponse, error) { + panic("unexpected call") +} + +func (s *stubLumeraClient) FinalizeAction(context.Context, string, []string) (*sdktx.BroadcastTxResponse, error) { + panic("unexpected call") +} + +func TestRegister_AbortsOnEventSendError(t *testing.T) { + sendErr := errors.New("send failed") + sendCalls := 0 + + service := &CascadeService{ + LumeraClient: &stubLumeraClient{action: &actiontypes.Action{ActionID: "action123"}}, + } + task := NewCascadeRegistrationTask(service) + + err := task.Register(context.Background(), &RegisterRequest{TaskID: "task123", ActionID: "action123"}, func(*RegisterResponse) error { + sendCalls++ + return sendErr + }) + + if !errors.Is(err, sendErr) { + t.Fatalf("expected send error; got %v", err) + } + if sendCalls != 1 { + t.Fatalf("expected 1 send call, got %d", sendCalls) + } +} + +func TestDownload_AbortsOnEventSendError(t *testing.T) { + sendErr := errors.New("send failed") + sendCalls := 0 + + service := &CascadeService{ + LumeraClient: &stubLumeraClient{action: &actiontypes.Action{ActionID: "action123", State: actiontypes.ActionStateDone}}, + } + task := NewCascadeRegistrationTask(service) + + err := task.Download(context.Background(), &DownloadRequest{ActionID: "action123"}, func(*DownloadResponse) error { + sendCalls++ + return sendErr + }) + + if !errors.Is(err, sendErr) { + t.Fatalf("expected send error; got %v", err) + } + if sendCalls != 1 { + t.Fatalf("expected 1 send call, got %d", sendCalls) + } +} diff --git a/supernode/cascade/task.go b/supernode/cascade/task.go index 71725d20..8539eee7 100644 --- a/supernode/cascade/task.go +++ b/supernode/cascade/task.go @@ -1,5 +1,7 @@ package cascade +import "context" + // CascadeRegistrationTask is the task for cascade registration type CascadeRegistrationTask struct { *CascadeService @@ -15,6 +17,10 @@ func NewCascadeRegistrationTask(service *CascadeService) *CascadeRegistrationTas } // streamEvent sends a RegisterResponse via the provided callback. -func (task *CascadeRegistrationTask) streamEvent(eventType SupernodeEventType, msg, txHash string, send func(resp *RegisterResponse) error) { - _ = send(&RegisterResponse{EventType: eventType, Message: msg, TxHash: txHash}) +// It propagates send failures so callers can abort work when the downstream is gone. +func (task *CascadeRegistrationTask) streamEvent(ctx context.Context, eventType SupernodeEventType, msg, txHash string, send func(resp *RegisterResponse) error) error { + if err := ctx.Err(); err != nil { + return err + } + return send(&RegisterResponse{EventType: eventType, Message: msg, TxHash: txHash}) } diff --git a/supernode/transport/grpc/cascade/chunksize.go b/supernode/transport/grpc/cascade/chunksize.go new file mode 100644 index 00000000..86cda1d8 --- /dev/null +++ b/supernode/transport/grpc/cascade/chunksize.go @@ -0,0 +1,34 @@ +package cascade + +// calculateOptimalChunkSize returns an optimal chunk size based on file size +// to balance throughput and memory usage +func calculateOptimalChunkSize(fileSize int64) int { + const ( + minChunkSize = 64 * 1024 // 64 KB minimum + maxChunkSize = 4 * 1024 * 1024 // 4 MB maximum for 1GB+ files + smallFileThreshold = 1024 * 1024 // 1 MB + mediumFileThreshold = 50 * 1024 * 1024 // 50 MB + largeFileThreshold = 500 * 1024 * 1024 // 500 MB + ) + + var chunkSize int + + switch { + case fileSize <= smallFileThreshold: + chunkSize = minChunkSize + case fileSize <= mediumFileThreshold: + chunkSize = 256 * 1024 + case fileSize <= largeFileThreshold: + chunkSize = 1024 * 1024 + default: + chunkSize = maxChunkSize + } + + if chunkSize < minChunkSize { + chunkSize = minChunkSize + } + if chunkSize > maxChunkSize { + chunkSize = maxChunkSize + } + return chunkSize +} diff --git a/supernode/transport/grpc/cascade/download.go b/supernode/transport/grpc/cascade/download.go new file mode 100644 index 00000000..7e25b23e --- /dev/null +++ b/supernode/transport/grpc/cascade/download.go @@ -0,0 +1,108 @@ +package cascade + +import ( + "fmt" + "io" + "os" + + pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade" + "github.com/LumeraProtocol/supernode/v2/pkg/logtrace" + tasks "github.com/LumeraProtocol/supernode/v2/pkg/task" + cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade" +) + +func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeService_DownloadServer) error { + ctx := stream.Context() + fields := logtrace.Fields{ + logtrace.FieldMethod: "Download", + logtrace.FieldModule: "CascadeActionServer", + logtrace.FieldActionID: req.GetActionId(), + } + logtrace.Debug(ctx, "download request received", fields) + + // Start live task tracking for the entire download RPC (including file streaming) + dlHandle := tasks.StartWith(server.tracker, ctx, serviceCascadeDownload, req.GetActionId(), server.downloadTimeout) + defer dlHandle.End(ctx) + + // Prepare to capture decoded file path from task events + var decodedFilePath string + var tmpDir string + + task := server.factory.NewCascadeRegistrationTask() + defer func() { + if tmpDir == "" { + return + } + if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil { + logtrace.Warn(ctx, "cleanup of tmp dir failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()}) + } + }() + // Run cascade task Download; stream events back to client + err := task.Download(ctx, &cascadeService.DownloadRequest{ActionID: req.GetActionId(), Signature: req.GetSignature()}, func(resp *cascadeService.DownloadResponse) error { + // Forward event to gRPC client + evt := &pb.DownloadResponse{ + ResponseType: &pb.DownloadResponse_Event{ + Event: &pb.DownloadEvent{ + EventType: pb.SupernodeEventType(resp.EventType), + Message: resp.Message, + }, + }, + } + if sendErr := stream.Send(evt); sendErr != nil { + return sendErr + } + // Capture decode-completed info for streaming + if resp.EventType == cascadeService.SupernodeEventTypeDecodeCompleted { + decodedFilePath = resp.FilePath + tmpDir = resp.DownloadedDir + } + return nil + }) + if err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "download task failed", fields) + return fmt.Errorf("download task failed: %w", err) + } + + if decodedFilePath == "" { + logtrace.Warn(ctx, "decode completed without file path", fields) + return nil + } + + // Notify client that server is ready to stream the file + logtrace.Debug(ctx, "download: serve ready", logtrace.Fields{"event_type": cascadeService.SupernodeEventTypeServeReady, logtrace.FieldActionID: req.GetActionId()}) + if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Event{Event: &pb.DownloadEvent{EventType: pb.SupernodeEventType_SERVE_READY, Message: "Serve ready"}}}); err != nil { + return fmt.Errorf("send serve-ready: %w", err) + } + + // Stream file content in chunks + fi, err := os.Stat(decodedFilePath) + if err != nil { + return fmt.Errorf("stat decoded file: %w", err) + } + chunkSize := calculateOptimalChunkSize(fi.Size()) + f, err := os.Open(decodedFilePath) + if err != nil { + return fmt.Errorf("open decoded file: %w", err) + } + defer f.Close() + + buf := make([]byte, chunkSize) + for { + n, rerr := f.Read(buf) + if n > 0 { + if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Chunk{Chunk: &pb.DataChunk{Data: buf[:n]}}}); err != nil { + return fmt.Errorf("send chunk: %w", err) + } + } + if rerr == io.EOF { + break + } + if rerr != nil { + return fmt.Errorf("read decoded file: %w", rerr) + } + } + + logtrace.Debug(ctx, "download stream completed", fields) + return nil +} diff --git a/supernode/transport/grpc/cascade/handler.go b/supernode/transport/grpc/cascade/handler.go index 96237b98..f2924ffc 100644 --- a/supernode/transport/grpc/cascade/handler.go +++ b/supernode/transport/grpc/cascade/handler.go @@ -1,20 +1,11 @@ package cascade import ( - "encoding/hex" - "fmt" - "hash" - "io" - "os" - "path/filepath" "time" pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade" - "github.com/LumeraProtocol/supernode/v2/pkg/errors" - "github.com/LumeraProtocol/supernode/v2/pkg/logtrace" tasks "github.com/LumeraProtocol/supernode/v2/pkg/task" cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade" - "lukechampine.com/blake3" ) type ActionServer struct { @@ -40,317 +31,3 @@ func NewCascadeActionServer(factory cascadeService.CascadeServiceFactory, tracke } return &ActionServer{factory: factory, tracker: tracker, uploadTimeout: uploadTO, downloadTimeout: downloadTO} } - -// calculateOptimalChunkSize returns an optimal chunk size based on file size -// to balance throughput and memory usage - -var ( - startedTask bool - handle *tasks.Handle -) - -func calculateOptimalChunkSize(fileSize int64) int { - const ( - minChunkSize = 64 * 1024 // 64 KB minimum - maxChunkSize = 4 * 1024 * 1024 // 4 MB maximum for 1GB+ files - smallFileThreshold = 1024 * 1024 // 1 MB - mediumFileThreshold = 50 * 1024 * 1024 // 50 MB - largeFileThreshold = 500 * 1024 * 1024 // 500 MB - ) - - var chunkSize int - - switch { - case fileSize <= smallFileThreshold: - chunkSize = minChunkSize - case fileSize <= mediumFileThreshold: - chunkSize = 256 * 1024 - case fileSize <= largeFileThreshold: - chunkSize = 1024 * 1024 - default: - chunkSize = maxChunkSize - } - - if chunkSize < minChunkSize { - chunkSize = minChunkSize - } - if chunkSize > maxChunkSize { - chunkSize = maxChunkSize - } - return chunkSize -} - -func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) error { - fields := logtrace.Fields{ - logtrace.FieldMethod: "Register", - logtrace.FieldModule: "CascadeActionServer", - } - - ctx := stream.Context() - logtrace.Info(ctx, "register: stream open", fields) - - const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit - - var ( - metadata *pb.Metadata - totalSize int - ) - - hasher, tempFile, tempFilePath, err := initializeHasherAndTempFile() - if err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to initialize hasher and temp file", fields) - return fmt.Errorf("initializing hasher and temp file: %w", err) - } - defer func(tempFile *os.File) { - err := tempFile.Close() - if err != nil && !errors.Is(err, os.ErrClosed) { - fields[logtrace.FieldError] = err.Error() - logtrace.Warn(ctx, "error closing temp file", fields) - } - }(tempFile) - - for { - req, err := stream.Recv() - if err == io.EOF { - break - } - if err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "error receiving stream data", fields) - return fmt.Errorf("failed to receive stream data: %w", err) - } - - switch x := req.RequestType.(type) { - case *pb.RegisterRequest_Chunk: - if x.Chunk != nil { - if _, err := hasher.Write(x.Chunk.Data); err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to write chunk to hasher", fields) - return fmt.Errorf("hashing error: %w", err) - } - if _, err := tempFile.Write(x.Chunk.Data); err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to write chunk to file", fields) - return fmt.Errorf("file write error: %w", err) - } - totalSize += len(x.Chunk.Data) - if totalSize > maxFileSize { - fields[logtrace.FieldError] = "file size exceeds 1GB limit" - fields["total_size"] = totalSize - logtrace.Error(ctx, "upload rejected: file too large", fields) - return fmt.Errorf("file size %d exceeds maximum allowed size of 1GB", totalSize) - } - // Keep chunk logs at debug to avoid verbosity - logtrace.Debug(ctx, "received data chunk", logtrace.Fields{"chunk_size": len(x.Chunk.Data), "total_size_so_far": totalSize}) - } - case *pb.RegisterRequest_Metadata: - metadata = x.Metadata - // Set correlation ID for the rest of the flow - ctx = logtrace.CtxWithCorrelationID(ctx, metadata.ActionId) - fields[logtrace.FieldTaskID] = metadata.GetTaskId() - fields[logtrace.FieldActionID] = metadata.GetActionId() - logtrace.Info(ctx, "register: metadata received", fields) - // Start live task tracking on first metadata (covers remaining stream and processing) - if !startedTask { - startedTask = true - handle = tasks.StartWith(server.tracker, ctx, serviceCascadeUpload, metadata.ActionId, server.uploadTimeout) - defer handle.End(ctx) - } - } - } - - if metadata == nil { - logtrace.Error(ctx, "no metadata received in stream", fields) - return fmt.Errorf("no metadata received") - } - fields[logtrace.FieldTaskID] = metadata.GetTaskId() - fields[logtrace.FieldActionID] = metadata.GetActionId() - logtrace.Info(ctx, "register: stream upload complete", fields) - - if err := tempFile.Sync(); err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to sync temp file", fields) - return fmt.Errorf("failed to sync temp file: %w", err) - } - - hash := hasher.Sum(nil) - hashHex := hex.EncodeToString(hash) - fields[logtrace.FieldHashHex] = hashHex - logtrace.Info(ctx, "register: hash computed", fields) - - targetPath, err := replaceTempDirWithTaskDir(metadata.GetTaskId(), tempFilePath, tempFile) - if err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to replace temp dir with task dir", fields) - return fmt.Errorf("failed to replace temp dir with task dir: %w", err) - } - - task := server.factory.NewCascadeRegistrationTask() - logtrace.Info(ctx, "register: task start", fields) - err = task.Register(ctx, &cascadeService.RegisterRequest{ - TaskID: metadata.TaskId, - ActionID: metadata.ActionId, - DataHash: hash, - DataSize: totalSize, - FilePath: targetPath, - }, func(resp *cascadeService.RegisterResponse) error { - grpcResp := &pb.RegisterResponse{ - EventType: pb.SupernodeEventType(resp.EventType), - Message: resp.Message, - TxHash: resp.TxHash, - } - if err := stream.Send(grpcResp); err != nil { - logtrace.Error(ctx, "failed to send response to client", logtrace.Fields{logtrace.FieldError: err.Error()}) - return err - } - // Mirror event to Info logs for high-level tracing - logtrace.Info(ctx, "register: event", logtrace.Fields{"event_type": resp.EventType, "message": resp.Message, logtrace.FieldTxHash: resp.TxHash, logtrace.FieldActionID: metadata.ActionId, logtrace.FieldTaskID: metadata.TaskId}) - return nil - }) - if err != nil { - logtrace.Error(ctx, "registration task failed", logtrace.Fields{logtrace.FieldError: err.Error()}) - return fmt.Errorf("registration failed: %w", err) - } - logtrace.Info(ctx, "register: task ok", fields) - return nil -} - -func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeService_DownloadServer) error { - ctx := stream.Context() - fields := logtrace.Fields{ - logtrace.FieldMethod: "Download", - logtrace.FieldModule: "CascadeActionServer", - logtrace.FieldActionID: req.GetActionId(), - } - logtrace.Debug(ctx, "download request received", fields) - - // Start live task tracking for the entire download RPC (including file streaming) - dlHandle := tasks.StartWith(server.tracker, ctx, serviceCascadeDownload, req.GetActionId(), server.downloadTimeout) - defer dlHandle.End(ctx) - - // Prepare to capture decoded file path from task events - var decodedFilePath string - var tmpDir string - - task := server.factory.NewCascadeRegistrationTask() - // Run cascade task Download; stream events back to client - err := task.Download(ctx, &cascadeService.DownloadRequest{ActionID: req.GetActionId(), Signature: req.GetSignature()}, func(resp *cascadeService.DownloadResponse) error { - // Forward event to gRPC client - evt := &pb.DownloadResponse{ - ResponseType: &pb.DownloadResponse_Event{ - Event: &pb.DownloadEvent{ - EventType: pb.SupernodeEventType(resp.EventType), - Message: resp.Message, - }, - }, - } - if sendErr := stream.Send(evt); sendErr != nil { - return sendErr - } - // Capture decode-completed info for streaming - if resp.EventType == cascadeService.SupernodeEventTypeDecodeCompleted { - decodedFilePath = resp.FilePath - tmpDir = resp.DownloadedDir - } - return nil - }) - if err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "download task failed", fields) - return fmt.Errorf("download task failed: %w", err) - } - - if decodedFilePath == "" { - logtrace.Warn(ctx, "decode completed without file path", fields) - return nil - } - - // Notify client that server is ready to stream the file - logtrace.Debug(ctx, "download: serve ready", logtrace.Fields{"event_type": cascadeService.SupernodeEventTypeServeReady, logtrace.FieldActionID: req.GetActionId()}) - if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Event{Event: &pb.DownloadEvent{EventType: pb.SupernodeEventType_SERVE_READY, Message: "Serve ready"}}}); err != nil { - return fmt.Errorf("send serve-ready: %w", err) - } - - // Stream file content in chunks - fi, err := os.Stat(decodedFilePath) - if err != nil { - return fmt.Errorf("stat decoded file: %w", err) - } - chunkSize := calculateOptimalChunkSize(fi.Size()) - f, err := os.Open(decodedFilePath) - if err != nil { - return fmt.Errorf("open decoded file: %w", err) - } - defer f.Close() - - buf := make([]byte, chunkSize) - for { - n, rerr := f.Read(buf) - if n > 0 { - if err := stream.Send(&pb.DownloadResponse{ResponseType: &pb.DownloadResponse_Chunk{Chunk: &pb.DataChunk{Data: append([]byte(nil), buf[:n]...)}}}); err != nil { - return fmt.Errorf("send chunk: %w", err) - } - } - if rerr == io.EOF { - break - } - if rerr != nil { - return fmt.Errorf("read decoded file: %w", rerr) - } - } - - // Cleanup temp directory if provided - if tmpDir != "" { - if cerr := task.CleanupDownload(ctx, tmpDir); cerr != nil { - logtrace.Warn(ctx, "cleanup of tmp dir failed", logtrace.Fields{"tmp_dir": tmpDir, logtrace.FieldError: cerr.Error()}) - } - } - - logtrace.Debug(ctx, "download stream completed", fields) - return nil -} - -// initializeHasherAndTempFile prepares a hasher and a temporary file to stream upload data into. -func initializeHasherAndTempFile() (hash.Hash, *os.File, string, error) { - // Create a temp directory for the upload - tmpDir, err := os.MkdirTemp("", "supernode-upload-*") - if err != nil { - return nil, nil, "", fmt.Errorf("create temp dir: %w", err) - } - - // Create a file within the temp directory - filePath := filepath.Join(tmpDir, "data.bin") - f, err := os.Create(filePath) - if err != nil { - return nil, nil, "", fmt.Errorf("create temp file: %w", err) - } - - // Create a BLAKE3 hasher (32 bytes output) - hasher := blake3.New(32, nil) - return hasher, f, filePath, nil -} - -// replaceTempDirWithTaskDir moves the uploaded file into a task-scoped directory -// and returns the new absolute path. -func replaceTempDirWithTaskDir(taskID, tempFilePath string, tempFile *os.File) (string, error) { - // Ensure data is flushed - _ = tempFile.Sync() - // Close now; deferred close may run later and is safe to ignore - _ = tempFile.Close() - - // Create a stable target directory under OS temp - targetDir := filepath.Join(os.TempDir(), "supernode", "uploads", taskID) - if err := os.MkdirAll(targetDir, 0700); err != nil { - return "", fmt.Errorf("create task dir: %w", err) - } - - newPath := filepath.Join(targetDir, filepath.Base(tempFilePath)) - if err := os.Rename(tempFilePath, newPath); err != nil { - return "", fmt.Errorf("move uploaded file: %w", err) - } - - // Attempt to cleanup the original temp directory - _ = os.RemoveAll(filepath.Dir(tempFilePath)) - return newPath, nil -} diff --git a/supernode/transport/grpc/cascade/handler_test.go b/supernode/transport/grpc/cascade/handler_test.go new file mode 100644 index 00000000..e457f444 --- /dev/null +++ b/supernode/transport/grpc/cascade/handler_test.go @@ -0,0 +1,106 @@ +package cascade + +import ( + "context" + "io" + "os" + "testing" + "time" + + pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade" + tasks "github.com/LumeraProtocol/supernode/v2/pkg/task" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type fakeRegisterStream struct { + ctx context.Context + reqs []*pb.RegisterRequest + i int +} + +func (s *fakeRegisterStream) Send(*pb.RegisterResponse) error { return nil } + +func (s *fakeRegisterStream) Recv() (*pb.RegisterRequest, error) { + if s.i >= len(s.reqs) { + return nil, io.EOF + } + r := s.reqs[s.i] + s.i++ + return r, nil +} + +func (s *fakeRegisterStream) SetHeader(metadata.MD) error { return nil } +func (s *fakeRegisterStream) SendHeader(metadata.MD) error { return nil } +func (s *fakeRegisterStream) SetTrailer(metadata.MD) {} +func (s *fakeRegisterStream) Context() context.Context { + if s.ctx != nil { + return s.ctx + } + return context.Background() +} +func (s *fakeRegisterStream) SendMsg(interface{}) error { return nil } +func (s *fakeRegisterStream) RecvMsg(interface{}) error { return nil } + +func TestRegister_CleansTempDirOnHandlerError(t *testing.T) { + tmpRoot := t.TempDir() + + prevTmpDir, hadPrevTmpDir := os.LookupEnv("TMPDIR") + t.Cleanup(func() { + if hadPrevTmpDir { + _ = os.Setenv("TMPDIR", prevTmpDir) + } else { + _ = os.Unsetenv("TMPDIR") + } + }) + if err := os.Setenv("TMPDIR", tmpRoot); err != nil { + t.Fatalf("set TMPDIR: %v", err) + } + + server := &ActionServer{} + err := server.Register(&fakeRegisterStream{}) + if err == nil { + t.Fatalf("expected error, got nil") + } + + entries, rerr := os.ReadDir(tmpRoot) + if rerr != nil { + t.Fatalf("read tmpRoot: %v", rerr) + } + if len(entries) != 0 { + t.Fatalf("expected TMPDIR to be empty, found %d entries", len(entries)) + } +} + +func TestRegister_RejectsDuplicateActionID(t *testing.T) { + tmpRoot := t.TempDir() + + prevTmpDir, hadPrevTmpDir := os.LookupEnv("TMPDIR") + t.Cleanup(func() { + if hadPrevTmpDir { + _ = os.Setenv("TMPDIR", prevTmpDir) + } else { + _ = os.Unsetenv("TMPDIR") + } + }) + if err := os.Setenv("TMPDIR", tmpRoot); err != nil { + t.Fatalf("set TMPDIR: %v", err) + } + + tr := tasks.New() + actionID := "action-1" + tr.Start(serviceCascadeUpload, actionID) + + server := &ActionServer{tracker: tr, uploadTimeout: time.Second} + stream := &fakeRegisterStream{ + reqs: []*pb.RegisterRequest{ + {RequestType: &pb.RegisterRequest_Metadata{Metadata: &pb.Metadata{TaskId: "task-1", ActionId: actionID}}}, + }, + } + + err := server.Register(stream) + if status.Code(err) != codes.AlreadyExists { + t.Fatalf("expected AlreadyExists, got %v", err) + } +} diff --git a/supernode/transport/grpc/cascade/register.go b/supernode/transport/grpc/cascade/register.go new file mode 100644 index 00000000..6e6818cf --- /dev/null +++ b/supernode/transport/grpc/cascade/register.go @@ -0,0 +1,222 @@ +package cascade + +import ( + "encoding/hex" + "fmt" + "hash" + "io" + "os" + "path/filepath" + + pb "github.com/LumeraProtocol/supernode/v2/gen/supernode/action/cascade" + "github.com/LumeraProtocol/supernode/v2/pkg/errors" + "github.com/LumeraProtocol/supernode/v2/pkg/logtrace" + tasks "github.com/LumeraProtocol/supernode/v2/pkg/task" + cascadeService "github.com/LumeraProtocol/supernode/v2/supernode/cascade" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "lukechampine.com/blake3" +) + +func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) error { + fields := logtrace.Fields{ + logtrace.FieldMethod: "Register", + logtrace.FieldModule: "CascadeActionServer", + } + + ctx := stream.Context() + logtrace.Info(ctx, "register: stream open", fields) + + const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit + + var ( + metadata *pb.Metadata + totalSize int + uploadHandle *tasks.Handle + ) + + hasher, tempFile, tempDir, tempFilePath, err := initializeHasherAndTempFile() + if err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "failed to initialize hasher and temp file", fields) + return fmt.Errorf("initializing hasher and temp file: %w", err) + } + defer func() { + if tempDir == "" { + return + } + if err := os.RemoveAll(tempDir); err != nil { + fields[logtrace.FieldError] = err.Error() + fields["temp_dir"] = tempDir + logtrace.Warn(ctx, "failed to cleanup upload temp dir", fields) + } + }() + defer func(tempFile *os.File) { + err := tempFile.Close() + if err != nil && !errors.Is(err, os.ErrClosed) { + fields[logtrace.FieldError] = err.Error() + logtrace.Warn(ctx, "error closing temp file", fields) + } + }(tempFile) + + for { + req, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "error receiving stream data", fields) + return fmt.Errorf("failed to receive stream data: %w", err) + } + + switch x := req.RequestType.(type) { + case *pb.RegisterRequest_Chunk: + if x.Chunk != nil { + if _, err := hasher.Write(x.Chunk.Data); err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "failed to write chunk to hasher", fields) + return fmt.Errorf("hashing error: %w", err) + } + if _, err := tempFile.Write(x.Chunk.Data); err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "failed to write chunk to file", fields) + return fmt.Errorf("file write error: %w", err) + } + totalSize += len(x.Chunk.Data) + if totalSize > maxFileSize { + fields[logtrace.FieldError] = "file size exceeds 1GB limit" + fields["total_size"] = totalSize + logtrace.Error(ctx, "upload rejected: file too large", fields) + return fmt.Errorf("file size %d exceeds maximum allowed size of 1GB", totalSize) + } + // Keep chunk logs at debug to avoid verbosity + logtrace.Debug(ctx, "received data chunk", logtrace.Fields{"chunk_size": len(x.Chunk.Data), "total_size_so_far": totalSize}) + } + case *pb.RegisterRequest_Metadata: + metadata = x.Metadata + // Set correlation ID for the rest of the flow + ctx = logtrace.CtxWithCorrelationID(ctx, metadata.ActionId) + fields[logtrace.FieldTaskID] = metadata.GetTaskId() + fields[logtrace.FieldActionID] = metadata.GetActionId() + logtrace.Info(ctx, "register: metadata received", fields) + actionID := metadata.GetActionId() + if actionID == "" { + return status.Error(codes.InvalidArgument, "missing action_id") + } + // Start live task tracking on first metadata (covers remaining stream and processing). + // Track by ActionID to prevent duplicate in-flight uploads for the same action. + if uploadHandle == nil { + h, herr := tasks.StartUniqueWith(server.tracker, ctx, serviceCascadeUpload, actionID, server.uploadTimeout) + if herr != nil { + if errors.Is(herr, tasks.ErrAlreadyRunning) { + return status.Errorf(codes.AlreadyExists, "upload already in progress for %s", actionID) + } + return herr + } + uploadHandle = h + defer uploadHandle.End(ctx) + } + } + } + + if metadata == nil { + logtrace.Error(ctx, "no metadata received in stream", fields) + return fmt.Errorf("no metadata received") + } + fields[logtrace.FieldTaskID] = metadata.GetTaskId() + fields[logtrace.FieldActionID] = metadata.GetActionId() + logtrace.Info(ctx, "register: stream upload complete", fields) + + if err := tempFile.Sync(); err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "failed to sync temp file", fields) + return fmt.Errorf("failed to sync temp file: %w", err) + } + + hash := hasher.Sum(nil) + hashHex := hex.EncodeToString(hash) + fields[logtrace.FieldHashHex] = hashHex + logtrace.Info(ctx, "register: hash computed", fields) + + targetPath, err := replaceTempDirWithTaskDir(metadata.GetTaskId(), tempFilePath, tempFile) + if err != nil { + fields[logtrace.FieldError] = err.Error() + logtrace.Error(ctx, "failed to replace temp dir with task dir", fields) + return fmt.Errorf("failed to replace temp dir with task dir: %w", err) + } + + task := server.factory.NewCascadeRegistrationTask() + logtrace.Info(ctx, "register: task start", fields) + err = task.Register(ctx, &cascadeService.RegisterRequest{ + TaskID: metadata.TaskId, + ActionID: metadata.ActionId, + DataHash: hash, + DataSize: totalSize, + FilePath: targetPath, + }, func(resp *cascadeService.RegisterResponse) error { + grpcResp := &pb.RegisterResponse{ + EventType: pb.SupernodeEventType(resp.EventType), + Message: resp.Message, + TxHash: resp.TxHash, + } + if err := stream.Send(grpcResp); err != nil { + logtrace.Error(ctx, "failed to send response to client", logtrace.Fields{logtrace.FieldError: err.Error()}) + return err + } + // Mirror event to Info logs for high-level tracing + logtrace.Info(ctx, "register: event", logtrace.Fields{"event_type": resp.EventType, "message": resp.Message, logtrace.FieldTxHash: resp.TxHash, logtrace.FieldActionID: metadata.ActionId, logtrace.FieldTaskID: metadata.TaskId}) + return nil + }) + if err != nil { + logtrace.Error(ctx, "registration task failed", logtrace.Fields{logtrace.FieldError: err.Error()}) + return fmt.Errorf("registration failed: %w", err) + } + logtrace.Info(ctx, "register: task ok", fields) + return nil +} + +// initializeHasherAndTempFile prepares a hasher and a temporary file to stream upload data into. +func initializeHasherAndTempFile() (hash.Hash, *os.File, string, string, error) { + // Create a temp directory for the upload + tmpDir, err := os.MkdirTemp("", "supernode-upload-*") + if err != nil { + return nil, nil, "", "", fmt.Errorf("create temp dir: %w", err) + } + + // Create a file within the temp directory + filePath := filepath.Join(tmpDir, "data.bin") + f, err := os.Create(filePath) + if err != nil { + _ = os.RemoveAll(tmpDir) + return nil, nil, "", "", fmt.Errorf("create temp file: %w", err) + } + + // Create a BLAKE3 hasher (32 bytes output) + hasher := blake3.New(32, nil) + return hasher, f, tmpDir, filePath, nil +} + +// replaceTempDirWithTaskDir moves the uploaded file into a task-scoped directory +// and returns the new absolute path. +func replaceTempDirWithTaskDir(taskID, tempFilePath string, tempFile *os.File) (string, error) { + // Ensure data is flushed + _ = tempFile.Sync() + // Close now; deferred close may run later and is safe to ignore + _ = tempFile.Close() + + // Create a stable target directory under OS temp + targetDir := filepath.Join(os.TempDir(), "supernode", "uploads", taskID) + if err := os.MkdirAll(targetDir, 0700); err != nil { + return "", fmt.Errorf("create task dir: %w", err) + } + + newPath := filepath.Join(targetDir, filepath.Base(tempFilePath)) + if err := os.Rename(tempFilePath, newPath); err != nil { + return "", fmt.Errorf("move uploaded file: %w", err) + } + + // Attempt to cleanup the original temp directory + _ = os.RemoveAll(filepath.Dir(tempFilePath)) + return newPath, nil +}