diff --git a/network/dag/interface.go b/network/dag/interface.go index 8b96036091..616bcd7c88 100644 --- a/network/dag/interface.go +++ b/network/dag/interface.go @@ -55,6 +55,12 @@ type State interface { // If the transaction already exists, nothing is added and no observers are notified. // The payload may be passed as well. Allowing for better notification of observers Add(ctx context.Context, transactions Transaction, payload []byte) error + // AddMany adds multiple transactions to the DAG in a single write transaction, + // requiring only a single fsync. Transactions are processed in order so that later + // transactions can reference earlier ones in the same batch. + // Returns the number of transactions successfully added and the first error encountered. + // Successfully added transactions are committed even when a later transaction fails. + AddMany(ctx context.Context, transactions []Transaction, payloads [][]byte) (int, error) // FindBetweenLC finds all transactions which lamport clock value lies between startInclusive and endExclusive. // They are returned in order: first sorted on lamport clock value, then on transaction reference (byte order). FindBetweenLC(ctx context.Context, startInclusive uint32, endExclusive uint32) ([]Transaction, error) diff --git a/network/dag/mock.go b/network/dag/mock.go index cf62a78aba..6449b450b1 100644 --- a/network/dag/mock.go +++ b/network/dag/mock.go @@ -58,6 +58,21 @@ func (mr *MockStateMockRecorder) Add(ctx, transactions, payload any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockState)(nil).Add), ctx, transactions, payload) } +// AddMany mocks base method. +func (m *MockState) AddMany(ctx context.Context, transactions []Transaction, payloads [][]byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddMany", ctx, transactions, payloads) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddMany indicates an expected call of AddMany. +func (mr *MockStateMockRecorder) AddMany(ctx, transactions, payloads any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMany", reflect.TypeOf((*MockState)(nil).AddMany), ctx, transactions, payloads) +} + // Configure mocks base method. func (m *MockState) Configure(config core.ServerConfig) error { m.ctrl.T.Helper() diff --git a/network/dag/state.go b/network/dag/state.go index fe60a74525..f1f24e1b6d 100644 --- a/network/dag/state.go +++ b/network/dag/state.go @@ -219,6 +219,96 @@ func (s *state) Add(ctx context.Context, transaction Transaction, payload []byte }), stoabs.WithWriteLock()) } +func (s *state) AddMany(ctx context.Context, transactions []Transaction, payloads [][]byte) (int, error) { + added := 0 + var txEvents []Event + var payloadEvents []Event + var firstErr error + + err := s.db.Write(ctx, func(tx stoabs.WriteTx) error { + for i, transaction := range transactions { + if ctx.Err() != nil { + firstErr = ctx.Err() + break + } + // Skip if already present + if s.graph.isPresent(tx, transaction.Ref()) { + continue + } + + // Verify within the write TX so that earlier TXs in the batch are visible + if err := s.verifyTX(tx, transaction); err != nil { + firstErr = fmt.Errorf("transaction verification failed (tx=%s): %w", transaction.Ref(), err) + break + } + + payload := payloads[i] + if payload != nil { + payloadHash := hash.SHA256Sum(payload) + if !transaction.PayloadHash().Equals(payloadHash) { + firstErr = fmt.Errorf("tx.PayloadHash does not match hash of payload (tx=%s)", transaction.Ref()) + break + } + if err := s.payloadStore.writePayload(tx, payloadHash, payload); err != nil { + firstErr = err + break + } + event := Event{ + Type: PayloadEventType, + Hash: transaction.Ref(), + Transaction: transaction, + Payload: payload, + } + if err := s.saveEvent(tx, event); err != nil { + firstErr = err + break + } + payloadEvents = append(payloadEvents, event) + } + + if err := s.graph.add(tx, transaction); err != nil { + firstErr = err + break + } + event := Event{ + Type: TransactionEventType, + Hash: transaction.Ref(), + Transaction: transaction, + Payload: payload, + } + if err := s.saveEvent(tx, event); err != nil { + firstErr = err + break + } + txEvents = append(txEvents, event) + + if err := s.updateState(tx, transaction); err != nil { + firstErr = err + break + } + added++ + } + // Always return nil to commit what we have, even if a TX failed + return nil + }, stoabs.OnRollback(func() { + log.Logger().Warn("Reloading the XOR and IBLT trees due to a DB transaction Rollback") + s.loadState(ctx) + }), stoabs.AfterCommit(func() { + for _, event := range txEvents { + s.notify(event) + } + for _, event := range payloadEvents { + s.notify(event) + } + }), stoabs.AfterCommit(func() { + s.transactionCount.Add(float64(added)) + }), stoabs.WithWriteLock()) + if err != nil { + return added, err + } + return added, firstErr +} + func (s *state) updateState(tx stoabs.WriteTx, transaction Transaction) error { clock := transaction.Clock() for { diff --git a/network/transport/v2/transactionlist_handler.go b/network/transport/v2/transactionlist_handler.go index ba4768c1c9..4dda2db288 100644 --- a/network/transport/v2/transactionlist_handler.go +++ b/network/transport/v2/transactionlist_handler.go @@ -94,28 +94,33 @@ func (p *protocol) handleTransactionList(ctx context.Context, connection grpc.Co return err } + // Validate that all public transactions include a payload before adding + payloads := make([][]byte, len(txs)) for i, tx := range txs { - if ctx.Err() != nil { - // For loop might be long-running, support cancellation - break - } // TODO does this always trigger fetching missing payloads? (through observer on DAG) Prolly not for v2 if len(tx.PAL()) == 0 && len(msg.Transactions[i].Payload) == 0 { return fmt.Errorf("peer did not provide payload for transaction (tx=%s)", tx.Ref()) } - if err = p.state.Add(ctx, tx, msg.Transactions[i].Payload); err != nil { - if errors.Is(err, dag.ErrPreviousTransactionMissing) { - p.cMan.done(cid) - log.Logger(). - WithFields(connection.Peer().ToFields()). - WithField(core.LogFieldConversationID, cid). - WithField(core.LogFieldTransactionRef, tx.Ref()). - Warn("Ignoring remainder of TransactionList due to missing prevs") - xor, clock := p.state.XOR(dag.MaxLamportClock) - return p.sender.sendState(connection, xor, clock) - } - return fmt.Errorf("unable to add received transaction to DAG (tx=%s): %w", tx.Ref(), err) + payloads[i] = msg.Transactions[i].Payload + } + + if ctx.Err() != nil { + return nil + } + + added, err := p.state.AddMany(ctx, txs, payloads) + if err != nil { + if errors.Is(err, dag.ErrPreviousTransactionMissing) { + p.cMan.done(cid) + log.Logger(). + WithFields(connection.Peer().ToFields()). + WithField(core.LogFieldConversationID, cid). + WithField(core.LogFieldTransactionRef, txs[added].Ref()). + Warn("Ignoring remainder of TransactionList due to missing prevs") + xor, clock := p.state.XOR(dag.MaxLamportClock) + return p.sender.sendState(connection, xor, clock) } + return fmt.Errorf("unable to add received transaction to DAG (tx=%s): %w", txs[added].Ref(), err) } if msg.MessageNumber >= msg.TotalMessages { diff --git a/network/transport/v2/transactionlist_handler_test.go b/network/transport/v2/transactionlist_handler_test.go index 0971900cbc..7b90a1ca09 100644 --- a/network/transport/v2/transactionlist_handler_test.go +++ b/network/transport/v2/transactionlist_handler_test.go @@ -79,7 +79,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { p, mocks := newTestProtocol(t, nil) conversation := p.cMan.startConversation(request, peer) envelope := envelopeWithConversation(conversation) - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(nil) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(1, nil) err := p.handleTransactionList(context.Background(), connection, envelope) @@ -102,7 +102,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { p, mocks := newTestProtocol(t, nil) conversation := p.cMan.startConversation(request, peer) envelope := envelopeWithConversation(conversation) - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(nil) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(1, nil) err := p.handleTransactionList(context.Background(), connection, envelope) @@ -113,7 +113,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { p, mocks := newTestProtocol(t, nil) conversation := p.cMan.startConversation(request, peer) envelope := envelopeWithConversation(conversation) - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(dag.ErrPreviousTransactionMissing) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(0, dag.ErrPreviousTransactionMissing) mocks.State.EXPECT().XOR(uint32(dag.MaxLamportClock)).Return(hash.FromSlice([]byte("stateXor")), uint32(7)) mocks.Sender.EXPECT().sendState(connection, hash.FromSlice([]byte("stateXor")), uint32(7)) @@ -127,7 +127,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { p, mocks := newTestProtocol(t, nil) conversation := p.cMan.startConversation(request, peer) envelope := envelopeWithConversation(conversation) - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(nil) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(1, nil) err := p.handleTransactionList(context.Background(), connection, envelope) @@ -142,7 +142,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { conversation := p.cMan.startConversation(request2, peer) cStartTime := conversation.expiry.Add(-1 * time.Millisecond) conversation.expiry = cStartTime - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(nil) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(1, nil) err := p.handleTransactionList(context.Background(), connection, &Envelope{Message: &Envelope_TransactionList{ TransactionList: &TransactionList{ @@ -163,7 +163,7 @@ func TestProtocol_handleTransactionList(t *testing.T) { p, mocks := newTestProtocol(t, nil) conversation := p.cMan.startConversation(request, peer) envelope := envelopeWithConversation(conversation) - mocks.State.EXPECT().Add(context.Background(), tx, payload).Return(errors.New("custom")) + mocks.State.EXPECT().AddMany(context.Background(), []dag.Transaction{tx}, [][]byte{payload}).Return(0, errors.New("custom")) err := p.handleTransactionList(context.Background(), connection, envelope)