From 71a6e740a3253c0f6419c8c23f7df1e25c7ce375 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 19 May 2026 13:17:03 +0200 Subject: [PATCH 01/14] feat(migration): headstate migration --- core/state/accessors.go | 20 ++- core/state/state.go | 2 +- migration/headstate/committer.go | 49 +++++++ migration/headstate/counter.go | 47 ++++++ migration/headstate/ingestor.go | 95 ++++++++++++ migration/headstate/migrator.go | 155 ++++++++++++++++++++ migration/headstate/migrator_test.go | 209 +++++++++++++++++++++++++++ node/migration.go | 4 +- 8 files changed, 578 insertions(+), 3 deletions(-) create mode 100644 migration/headstate/committer.go create mode 100644 migration/headstate/counter.go create mode 100644 migration/headstate/ingestor.go create mode 100644 migration/headstate/migrator.go create mode 100644 migration/headstate/migrator_test.go diff --git a/core/state/accessors.go b/core/state/accessors.go index a121ab7808..2b289ae70b 100644 --- a/core/state/accessors.go +++ b/core/state/accessors.go @@ -41,7 +41,25 @@ func HasContract(r db.KeyValueReader, addr *felt.Felt) (bool, error) { return r.Has(key) } -func WriteContract(w db.KeyValueWriter, addr *felt.Felt, contract *stateContract) error { +// WriteContract writes a Contract record from raw fields. Used by the running +// node (via writeContract on a fully-built stateContract) and by the deprecated +// → new state migration (with StorageRoot left zero — the new state lazily +// backfills it on the contract's first storage write). +func WriteContract( + w db.KeyValueWriter, + addr *felt.Felt, + nonce, classHash felt.Felt, + deployHeight uint64, +) error { + contract := stateContract{ + Nonce: nonce, + ClassHash: classHash, + DeployedHeight: deployHeight, + } + return writeContract(w, addr, &contract) +} + +func writeContract(w db.KeyValueWriter, addr *felt.Felt, contract *stateContract) error { key := db.ContractKey(addr) data, err := contract.MarshalBinary() if err != nil { diff --git a/core/state/state.go b/core/state/state.go index 6ea9cb5543..edd6b8d234 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -376,7 +376,7 @@ func (s *State) flush( return err } } else { // updated - if err := WriteContract(s.batch, &addr, obj.contract); err != nil { + if err := writeContract(s.batch, &addr, obj.contract); err != nil { return err } } diff --git a/migration/headstate/committer.go b/migration/headstate/committer.go new file mode 100644 index 0000000000..962cac7a2e --- /dev/null +++ b/migration/headstate/committer.go @@ -0,0 +1,49 @@ +package headstate + +import ( + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +type committer struct { + counter counter + logger log.StructuredLogger + batchSemaphore semaphore.ResourceSemaphore[db.Batch] +} + +var _ pipeline.State[task, struct{}] = (*committer)(nil) + +func newCommitter( + logger log.StructuredLogger, + batchSemaphore semaphore.ResourceSemaphore[db.Batch], +) *committer { + return &committer{ + logger: logger, + counter: newCounter(logger, timeLogRate), + batchSemaphore: batchSemaphore, + } +} + +func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { + c.logger.Debug( + "writing batch", + zap.Int("addrCount", t.addrCount), + zap.Int("batchSize", t.batch.Size()), + ) + + byteSize := uint64(t.batch.Size()) + if err := t.batch.Write(); err != nil { + return err + } + + c.counter.log(byteSize, t.addrCount) + c.batchSemaphore.Put() + return nil +} + +func (c *committer) Done(int, chan<- struct{}) error { + return nil +} diff --git a/migration/headstate/counter.go b/migration/headstate/counter.go new file mode 100644 index 0000000000..cf3f9b28e9 --- /dev/null +++ b/migration/headstate/counter.go @@ -0,0 +1,47 @@ +package headstate + +import ( + "time" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +type counter struct { + logger log.StructuredLogger + timeLogRate time.Duration + start time.Time + size uint64 + addrCount uint64 +} + +func newCounter(logger log.StructuredLogger, timeLogRate time.Duration) counter { + return counter{ + logger: logger, + timeLogRate: timeLogRate, + start: time.Now(), + } +} + +func (c *counter) log(byteSize uint64, addrCount int) { + c.size += byteSize + c.addrCount += uint64(addrCount) + + now := time.Now() + elapsed := now.Sub(c.start).Seconds() + if elapsed > float64(c.timeLogRate.Seconds()) { + mbs := float64(c.size) / float64(db.Megabyte) + c.logger.Info( + "write speed", + zap.Float64("MB", mbs), + zap.Float64("MB/s", mbs/elapsed), + zap.Uint64("contracts", c.addrCount), + zap.Float64("contracts/s", float64(c.addrCount)/elapsed), + zap.Float64("time", elapsed), + ) + c.start = now + c.size = 0 + c.addrCount = 0 + } +} diff --git a/migration/headstate/ingestor.go b/migration/headstate/ingestor.go new file mode 100644 index 0000000000..cfe60388e4 --- /dev/null +++ b/migration/headstate/ingestor.go @@ -0,0 +1,95 @@ +package headstate + +import ( + "errors" + "fmt" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" +) + +type ingestor struct { + database db.KeyValueReader + batchSemaphore semaphore.ResourceSemaphore[db.Batch] + tasks []task +} + +func newIngestor( + database db.KeyValueReader, + batchSemaphore semaphore.ResourceSemaphore[db.Batch], +) *ingestor { + tasks := make([]task, ingestorCount) + for i := range tasks { + tasks[i] = task{batch: batchSemaphore.GetBlocking()} + } + return &ingestor{ + database: database, + batchSemaphore: batchSemaphore, + tasks: tasks, + } +} + +var _ pipeline.State[felt.Address, task] = (*ingestor)(nil) + +func (c *ingestor) Run(index int, addr felt.Address, outputs chan<- task) error { + t := &c.tasks[index] + + sizeBefore := t.batch.Size() + if err := c.ingestAddress(t.batch, addr); err != nil { + return err + } + if t.batch.Size() > sizeBefore { + t.addrCount++ + } + + if t.batch.Size() >= targetBatchByteSize { + outputs <- task{batch: t.batch, addrCount: t.addrCount} + t.addrCount = 0 + t.batch = c.batchSemaphore.GetBlocking() + } + return nil +} + +func (c *ingestor) Done(index int, outputs chan<- task) error { + outputs <- c.tasks[index] + return nil +} + +func (c *ingestor) ingestAddress(batch db.Batch, addr felt.Address) error { + addrFelt := felt.Felt(addr) + + already, err := state.HasContract(c.database, &addrFelt) + if err != nil { + return fmt.Errorf("HasContract(%s): %w", &addrFelt, err) + } + if already { + return nil + } + + classHash, err := core.GetContractClassHash(c.database, &addrFelt) + if err != nil { + return fmt.Errorf("GetContractClassHash(%s): %w", &addrFelt, err) + } + + nonce, err := core.GetContractNonce(c.database, &addrFelt) + if err != nil { + if !errors.Is(err, db.ErrKeyNotFound) { + return fmt.Errorf("GetContractNonce(%s): %w", &addrFelt, err) + } + nonce = felt.Zero + } + + height, err := core.GetContractDeploymentHeight(c.database, &addrFelt) + if err != nil { + return fmt.Errorf("GetContractDeploymentHeight(%s): %w", &addrFelt, err) + } + + if err := state.WriteContract(batch, &addrFelt, nonce, classHash, height); err != nil { + return fmt.Errorf("WriteContract(%s): %w", &addrFelt, err) + } + return nil +} diff --git a/migration/headstate/migrator.go b/migration/headstate/migrator.go new file mode 100644 index 0000000000..6f9b43c840 --- /dev/null +++ b/migration/headstate/migrator.go @@ -0,0 +1,155 @@ +package headstate + +import ( + "context" + "fmt" + "iter" + "time" + + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" + "github.com/NethermindEth/juno/migration" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" +) + +const ( + batchByteSize = 128 * db.Megabyte + targetBatchByteSize = 96 * db.Megabyte + ingestorCount = 4 + timeLogRate = 5 * time.Second +) + +type task struct { + batch db.Batch + addrCount int +} + +var ( + shouldRerun = []byte{} + shouldNotRerun = []byte(nil) +) + +func migrateAddresses( + ctx context.Context, + database db.KeyValueStore, + logger log.StructuredLogger, + addresses iter.Seq[felt.Address], +) pipeline.Result { + batchSemaphore := semaphore.New( + ingestorCount+1, + func() db.Batch { + return database.NewBatchWithSize(batchByteSize) + }, + ) + + source := pipeline.Source(addresses) + + ingestorPipeline := pipeline.New( + source, + ingestorCount, + newIngestor(database, batchSemaphore), + ) + + committerPipeline := pipeline.New( + ingestorPipeline, + 1, + newCommitter(logger, batchSemaphore), + ) + + _, wait := committerPipeline.Run(ctx) + return wait() +} + +var _ migration.Migration = (*Migrator)(nil) + +type Migrator struct{} + +func (Migrator) Before([]byte) error { + return nil +} + +func (Migrator) Migrate( + ctx context.Context, + database db.KeyValueStore, + _ *networks.Network, + logger log.StructuredLogger, +) ([]byte, error) { + hasPending, err := hasPendingAddresses(database) + if err != nil { + return shouldRerun, err + } + if !hasPending { + return shouldNotRerun, wipeDeprecatedBuckets(database) + } + + addressesIter, sourceErr := pendingAddresses(database) + res := migrateAddresses(ctx, database, logger, addressesIter) + + if err := sourceErr(); err != nil { + return shouldRerun, err + } + if res.Err != nil || !res.IsDone { + return shouldRerun, res.Err + } + + return shouldNotRerun, wipeDeprecatedBuckets(database) +} + +func hasPendingAddresses(r db.KeyValueReader) (bool, error) { + prefix := db.ContractClassHash.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + return false, err + } + defer it.Close() + return it.First(), nil +} + +func pendingAddresses(r db.KeyValueReader) (iter.Seq[felt.Address], func() error) { + var iterErr error + seq := func(yield func(felt.Address) bool) { + prefix := db.ContractClassHash.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + iterErr = err + return + } + defer it.Close() + + for valid := it.First(); valid; valid = it.Next() { + key := it.Key() + if len(key) != len(prefix)+felt.Bytes { + iterErr = fmt.Errorf( + "malformed ContractClassHash key: len %d, want %d", + len(key), + len(prefix)+felt.Bytes, + ) + return + } + f := felt.FromBytes[felt.Felt](key[len(prefix):]) + if !yield(felt.Address(f)) { + return + } + } + } + return seq, func() error { return iterErr } +} + +func wipeDeprecatedBuckets(database db.KeyValueStore) error { + for _, bucket := range []db.Bucket{ + db.ContractClassHash, + db.ContractNonce, + db.ContractDeploymentHeight, + } { + start := bucket.Key() + end := dbutils.UpperBound(start) + if err := database.DeleteRange(start, end); err != nil { + return err + } + } + return nil +} diff --git a/migration/headstate/migrator_test.go b/migration/headstate/migrator_test.go new file mode 100644 index 0000000000..e8b281a727 --- /dev/null +++ b/migration/headstate/migrator_test.go @@ -0,0 +1,209 @@ +package headstate_test + +import ( + "context" + "testing" + + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/migration/headstate" + "github.com/NethermindEth/juno/utils/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type contractData struct { + addr felt.Felt + classHash felt.Felt + nonce felt.Felt // felt.Zero means "do not write nonce entry" + height uint64 +} + +func seedDeprecated(t *testing.T, memDB db.KeyValueStore, seeds []contractData) { + t.Helper() + for i := range seeds { + s := &seeds[i] + require.NoError(t, core.WriteContractClassHash(memDB, &s.addr, &s.classHash)) + if !s.nonce.IsZero() { + require.NoError(t, core.WriteContractNonce(memDB, &s.addr, &s.nonce)) + } + require.NoError(t, core.WriteContractDeploymentHeight(memDB, &s.addr, s.height)) + } +} + +func bucketKeyCount(t *testing.T, r db.KeyValueReader, bucket db.Bucket) int { + t.Helper() + it, err := r.NewIterator(bucket.Key(), true) + require.NoError(t, err) + defer it.Close() + count := 0 + for valid := it.First(); valid; valid = it.Next() { + count++ + } + return count +} + +func TestMigrate_EmptyDB(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + res, err := headstate.Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + assert.Nil(t, res, "empty DB must complete with no intermediate state") +} + +func TestMigrate_ConsolidatesAddresses(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + seeds := []contractData{ + { + addr: felt.FromUint64[felt.Felt](1), + classHash: felt.FromUint64[felt.Felt](170), + nonce: felt.FromUint64[felt.Felt](7), + height: 100, + }, + { + addr: felt.FromUint64[felt.Felt](2), + classHash: felt.FromUint64[felt.Felt](187), + nonce: felt.Zero, // never updated + height: 200, + }, + { + addr: felt.FromUint64[felt.Felt](3), + classHash: felt.FromUint64[felt.Felt](204), + nonce: felt.FromUint64[felt.Felt](66), + height: 300, + }, + } + seedDeprecated(t, memDB, seeds) + + res, err := headstate.Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + assert.Nil(t, res) + + for i := range seeds { + s := &seeds[i] + got, err := state.GetContract(memDB, &s.addr) + require.NoErrorf(t, err, "missing Contract for %s", &s.addr) + assert.Equal(t, s.classHash, got.ClassHash, "ClassHash for %s", &s.addr) + assert.Equal(t, s.nonce, got.Nonce, "Nonce for %s", &s.addr) + assert.Equal(t, s.height, got.DeployedHeight, "DeployedHeight for %s", &s.addr) + assert.True(t, got.StorageRoot.IsZero(), "StorageRoot must be zero (lazy)") + } + + for _, bucket := range []db.Bucket{ + db.ContractClassHash, + db.ContractNonce, + db.ContractDeploymentHeight, + } { + assert.Equal(t, 0, bucketKeyCount(t, memDB, bucket), "old bucket %v must be empty", bucket) + } +} + +func TestMigrate_SkipsAlreadyMigrated(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + // addrDone has been migrated already (Contract record exists). Its old + // keys are still around (simulating a previous partial run that wrote + // the Contract record but didn't reach wipeDeprecatedBuckets). + addrDone := felt.FromUint64[felt.Felt](1) + doneClassHash := felt.FromUint64[felt.Felt](170) + doneNonce := felt.FromUint64[felt.Felt](57005) + + require.NoError(t, state.WriteContract(memDB, &addrDone, doneNonce, doneClassHash, 111)) + require.NoError(t, core.WriteContractClassHash(memDB, &addrDone, &doneClassHash)) + require.NoError(t, core.WriteContractDeploymentHeight(memDB, &addrDone, 111)) + + addrPending := felt.FromUint64[felt.Felt](2) + pendingClassHash := felt.FromUint64[felt.Felt](187) + pendingNonce := felt.FromUint64[felt.Felt](9) + + require.NoError(t, core.WriteContractClassHash(memDB, &addrPending, &pendingClassHash)) + require.NoError(t, core.WriteContractNonce(memDB, &addrPending, &pendingNonce)) + require.NoError(t, core.WriteContractDeploymentHeight(memDB, &addrPending, 222)) + + res, err := headstate.Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + assert.Nil(t, res) + + // addrDone's pre-existing Contract was preserved (not overwritten). + done, err := state.GetContract(memDB, &addrDone) + require.NoError(t, err) + assert.Equal(t, doneNonce, done.Nonce, "addrDone's pre-existing Contract must not be overwritten") + assert.Equal(t, uint64(111), done.DeployedHeight) + + // addrPending got migrated. + pending, err := state.GetContract(memDB, &addrPending) + require.NoError(t, err) + assert.Equal(t, pendingClassHash, pending.ClassHash) + assert.Equal(t, pendingNonce, pending.Nonce) + assert.Equal(t, uint64(222), pending.DeployedHeight) + + for _, bucket := range []db.Bucket{ + db.ContractClassHash, + db.ContractNonce, + db.ContractDeploymentHeight, + } { + assert.Equal(t, 0, bucketKeyCount(t, memDB, bucket), "old bucket %v must be empty", bucket) + } +} + +func TestMigrate_Idempotent(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + seeds := []contractData{ + { + addr: felt.FromUint64[felt.Felt](1), + classHash: felt.FromUint64[felt.Felt](170), + nonce: felt.FromUint64[felt.Felt](7), + height: 100, + }, + { + addr: felt.FromUint64[felt.Felt](2), + classHash: felt.FromUint64[felt.Felt](187), + nonce: felt.Zero, + height: 200, + }, + } + seedDeprecated(t, memDB, seeds) + + for i := 0; i < 3; i++ { + res, err := headstate.Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoErrorf(t, err, "run %d", i) + assert.Nilf(t, res, "run %d intermediate state", i) + } + + for i := range seeds { + s := &seeds[i] + got, err := state.GetContract(memDB, &s.addr) + require.NoError(t, err) + assert.Equal(t, s.classHash, got.ClassHash) + } +} diff --git a/node/migration.go b/node/migration.go index 4b7576d403..71f387597f 100644 --- a/node/migration.go +++ b/node/migration.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/migration" "github.com/NethermindEth/juno/migration/blocktransactions" "github.com/NethermindEth/juno/migration/deprecated" //nolint:staticcheck,nolintlint,lll // ignore statick check package will be removed in future, nolinlint because main config does not check + "github.com/NethermindEth/juno/migration/headstate" "github.com/NethermindEth/juno/migration/historyprunner" "github.com/NethermindEth/juno/utils/log" ) @@ -25,7 +26,8 @@ func registerMigrations(cfg *Config) *migration.Registry { historyprunner.New(cfg.RetainedBlocks), cfg.Prune, PruneModeFlag, - ) + ). + WithOptional(&headstate.Migrator{}, cfg.NewState, "new-state") return registry } From c2eaccedc1025ac0f2bb805916c70a4517987215 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 19 May 2026 23:48:00 +0200 Subject: [PATCH 02/14] chore: self review --- core/state/accessors.go | 4 -- migration/headstate/committer.go | 4 +- migration/headstate/counter.go | 22 +++--- migration/headstate/ingestor.go | 6 +- migration/headstate/migrator.go | 102 +++++++++++++++------------ migration/headstate/migrator_test.go | 2 +- 6 files changed, 72 insertions(+), 68 deletions(-) diff --git a/core/state/accessors.go b/core/state/accessors.go index 2b289ae70b..4996937076 100644 --- a/core/state/accessors.go +++ b/core/state/accessors.go @@ -41,10 +41,6 @@ func HasContract(r db.KeyValueReader, addr *felt.Felt) (bool, error) { return r.Has(key) } -// WriteContract writes a Contract record from raw fields. Used by the running -// node (via writeContract on a fully-built stateContract) and by the deprecated -// → new state migration (with StorageRoot left zero — the new state lazily -// backfills it on the contract's first storage write). func WriteContract( w db.KeyValueWriter, addr *felt.Felt, diff --git a/migration/headstate/committer.go b/migration/headstate/committer.go index 962cac7a2e..cf5f15c03f 100644 --- a/migration/headstate/committer.go +++ b/migration/headstate/committer.go @@ -30,7 +30,7 @@ func newCommitter( func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { c.logger.Debug( "writing batch", - zap.Int("addrCount", t.addrCount), + zap.Int("completedAddrs", t.completedAddrs), zap.Int("batchSize", t.batch.Size()), ) @@ -39,7 +39,7 @@ func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { return err } - c.counter.log(byteSize, t.addrCount) + c.counter.log(byteSize, t.completedAddrs) c.batchSemaphore.Put() return nil } diff --git a/migration/headstate/counter.go b/migration/headstate/counter.go index cf3f9b28e9..db427d11fc 100644 --- a/migration/headstate/counter.go +++ b/migration/headstate/counter.go @@ -9,11 +9,11 @@ import ( ) type counter struct { - logger log.StructuredLogger - timeLogRate time.Duration - start time.Time - size uint64 - addrCount uint64 + logger log.StructuredLogger + timeLogRate time.Duration + start time.Time + size uint64 + completedAddrs uint64 } func newCounter(logger log.StructuredLogger, timeLogRate time.Duration) counter { @@ -24,24 +24,24 @@ func newCounter(logger log.StructuredLogger, timeLogRate time.Duration) counter } } -func (c *counter) log(byteSize uint64, addrCount int) { +func (c *counter) log(byteSize uint64, completedAddrs int) { c.size += byteSize - c.addrCount += uint64(addrCount) + c.completedAddrs += uint64(completedAddrs) now := time.Now() elapsed := now.Sub(c.start).Seconds() - if elapsed > float64(c.timeLogRate.Seconds()) { + if elapsed > c.timeLogRate.Seconds() { mbs := float64(c.size) / float64(db.Megabyte) c.logger.Info( "write speed", zap.Float64("MB", mbs), zap.Float64("MB/s", mbs/elapsed), - zap.Uint64("contracts", c.addrCount), - zap.Float64("contracts/s", float64(c.addrCount)/elapsed), + zap.Uint64("completedContracts", c.completedAddrs), + zap.Float64("completedContracts/s", float64(c.completedAddrs)/elapsed), zap.Float64("time", elapsed), ) c.start = now c.size = 0 - c.addrCount = 0 + c.completedAddrs = 0 } } diff --git a/migration/headstate/ingestor.go b/migration/headstate/ingestor.go index cfe60388e4..a8df2739f9 100644 --- a/migration/headstate/ingestor.go +++ b/migration/headstate/ingestor.go @@ -43,12 +43,12 @@ func (c *ingestor) Run(index int, addr felt.Address, outputs chan<- task) error return err } if t.batch.Size() > sizeBefore { - t.addrCount++ + t.completedAddrs++ } if t.batch.Size() >= targetBatchByteSize { - outputs <- task{batch: t.batch, addrCount: t.addrCount} - t.addrCount = 0 + outputs <- task{batch: t.batch, completedAddrs: t.completedAddrs} + t.completedAddrs = 0 t.batch = c.batchSemaphore.GetBlocking() } return nil diff --git a/migration/headstate/migrator.go b/migration/headstate/migrator.go index 6f9b43c840..c669f8a20c 100644 --- a/migration/headstate/migrator.go +++ b/migration/headstate/migrator.go @@ -2,6 +2,7 @@ package headstate import ( "context" + "errors" "fmt" "iter" "time" @@ -24,8 +25,8 @@ const ( ) type task struct { - batch db.Batch - addrCount int + batch db.Batch + completedAddrs int } var ( @@ -33,6 +34,58 @@ var ( shouldNotRerun = []byte(nil) ) +var _ migration.Migration = (*Migrator)(nil) + +// Migrator consolidates the deprecated per-field contract layout into a +// single Contract record per address, written via state.WriteContract: +// +// ContractClassHash[addr] +// ContractNonce[addr] +// ContractDeploymentHeight[addr] +// │ +// ▼ +// Contract[addr] = { ClassHash, Nonce, DeployedHeight } +// +// StorageRoot is left zero — the running node lazily backfills it on the +// contract's first storage write. +// +// Each address discovered in the ContractClassHash bucket is processed by one +// of ingestorCount worker goroutines that read the three old fields into a +// shared db.Batch; a single committer drains batches to disk. Once every +// address has been migrated, the three deprecated buckets are wiped via +// DeleteRange. +// +// Re-run safe: an address whose Contract record already exists is skipped +// (via state.HasContract), and the trailing wipe re-issues DeleteRange over +// the (possibly already empty) ranges. +type Migrator struct{} + +func (Migrator) Before([]byte) error { + return nil +} + +func (Migrator) Migrate( + ctx context.Context, + database db.KeyValueStore, + _ *networks.Network, + logger log.StructuredLogger, +) ([]byte, error) { + addressesIter, sourceErr := pendingAddresses(database) + res := migrateAddresses(ctx, database, logger, addressesIter) + + if err := errors.Join(sourceErr(), res.Err); err != nil { + return shouldRerun, err + } + if !res.IsDone { + if ctxErr := ctx.Err(); ctxErr != nil { + return shouldRerun, ctxErr + } + return shouldRerun, errors.New("headstate migration did not complete") + } + + return shouldNotRerun, wipeDeprecatedBuckets(database) +} + func migrateAddresses( ctx context.Context, database db.KeyValueStore, @@ -64,51 +117,6 @@ func migrateAddresses( return wait() } -var _ migration.Migration = (*Migrator)(nil) - -type Migrator struct{} - -func (Migrator) Before([]byte) error { - return nil -} - -func (Migrator) Migrate( - ctx context.Context, - database db.KeyValueStore, - _ *networks.Network, - logger log.StructuredLogger, -) ([]byte, error) { - hasPending, err := hasPendingAddresses(database) - if err != nil { - return shouldRerun, err - } - if !hasPending { - return shouldNotRerun, wipeDeprecatedBuckets(database) - } - - addressesIter, sourceErr := pendingAddresses(database) - res := migrateAddresses(ctx, database, logger, addressesIter) - - if err := sourceErr(); err != nil { - return shouldRerun, err - } - if res.Err != nil || !res.IsDone { - return shouldRerun, res.Err - } - - return shouldNotRerun, wipeDeprecatedBuckets(database) -} - -func hasPendingAddresses(r db.KeyValueReader) (bool, error) { - prefix := db.ContractClassHash.Key() - it, err := r.NewIterator(prefix, true) - if err != nil { - return false, err - } - defer it.Close() - return it.First(), nil -} - func pendingAddresses(r db.KeyValueReader) (iter.Seq[felt.Address], func() error) { var iterErr error seq := func(yield func(felt.Address) bool) { diff --git a/migration/headstate/migrator_test.go b/migration/headstate/migrator_test.go index e8b281a727..99048754eb 100644 --- a/migration/headstate/migrator_test.go +++ b/migration/headstate/migrator_test.go @@ -189,7 +189,7 @@ func TestMigrate_Idempotent(t *testing.T) { } seedDeprecated(t, memDB, seeds) - for i := 0; i < 3; i++ { + for i := range 3 { res, err := headstate.Migrator{}.Migrate( context.Background(), memDB, From 17d07f2e4bfe91d10257bade878de84d42e5b6c0 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 19 May 2026 13:17:51 +0200 Subject: [PATCH 03/14] feat(migration): statehistory migration --- migration/statehistory/committer.go | 51 +++ migration/statehistory/counter.go | 55 +++ migration/statehistory/ingestor.go | 57 +++ migration/statehistory/migrator.go | 133 ++++++ migration/statehistory/migrator_test.go | 570 ++++++++++++++++++++++++ migration/statehistory/parse.go | 28 ++ migration/statehistory/transform.go | 259 +++++++++++ node/migration.go | 4 +- 8 files changed, 1156 insertions(+), 1 deletion(-) create mode 100644 migration/statehistory/committer.go create mode 100644 migration/statehistory/counter.go create mode 100644 migration/statehistory/ingestor.go create mode 100644 migration/statehistory/migrator.go create mode 100644 migration/statehistory/migrator_test.go create mode 100644 migration/statehistory/parse.go create mode 100644 migration/statehistory/transform.go diff --git a/migration/statehistory/committer.go b/migration/statehistory/committer.go new file mode 100644 index 0000000000..4f409b9352 --- /dev/null +++ b/migration/statehistory/committer.go @@ -0,0 +1,51 @@ +package statehistory + +import ( + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +type committer struct { + logger log.StructuredLogger + counter counter + batchSemaphore semaphore.ResourceSemaphore[db.Batch] +} + +var _ pipeline.State[task, struct{}] = (*committer)(nil) + +func newCommitter( + logger log.StructuredLogger, + batchSemaphore semaphore.ResourceSemaphore[db.Batch], + phaseName string, +) *committer { + return &committer{ + logger: logger, + counter: newCounter(logger, timeLogRate, phaseName), + batchSemaphore: batchSemaphore, + } +} + +func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { + c.logger.Debug( + "writing batch", + zap.Int("addrCount", t.addrCount), + zap.Int("entryCount", t.entryCount), + zap.Int("batchSize", t.batch.Size()), + ) + + byteSize := uint64(t.batch.Size()) + if err := t.batch.Write(); err != nil { + return err + } + + c.counter.log(byteSize, t.addrCount, t.entryCount) + c.batchSemaphore.Put() + return nil +} + +func (c *committer) Done(int, chan<- struct{}) error { + return nil +} diff --git a/migration/statehistory/counter.go b/migration/statehistory/counter.go new file mode 100644 index 0000000000..d5b10f26d8 --- /dev/null +++ b/migration/statehistory/counter.go @@ -0,0 +1,55 @@ +package statehistory + +import ( + "time" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +type counter struct { + logger log.StructuredLogger + timeLogRate time.Duration + phaseName string + start time.Time + size uint64 + addrCount uint64 + entryCount uint64 +} + +func newCounter(logger log.StructuredLogger, timeLogRate time.Duration, phaseName string) counter { + return counter{ + logger: logger, + timeLogRate: timeLogRate, + phaseName: phaseName, + start: time.Now(), + } +} + +func (c *counter) log(byteSize uint64, addrCount, entryCount int) { + c.size += byteSize + c.addrCount += uint64(addrCount) + c.entryCount += uint64(entryCount) + + now := time.Now() + elapsed := now.Sub(c.start).Seconds() + if elapsed > float64(c.timeLogRate.Seconds()) { + mbs := float64(c.size) / float64(db.Megabyte) + c.logger.Info( + "write speed", + zap.String("phase", c.phaseName), + zap.Float64("MB", mbs), + zap.Float64("MB/s", mbs/elapsed), + zap.Uint64("contracts", c.addrCount), + zap.Float64("contracts/s", float64(c.addrCount)/elapsed), + zap.Uint64("entries", c.entryCount), + zap.Float64("entries/s", float64(c.entryCount)/elapsed), + zap.Float64("time", elapsed), + ) + c.start = now + c.size = 0 + c.addrCount = 0 + c.entryCount = 0 + } +} diff --git a/migration/statehistory/ingestor.go b/migration/statehistory/ingestor.go new file mode 100644 index 0000000000..f58e825a8d --- /dev/null +++ b/migration/statehistory/ingestor.go @@ -0,0 +1,57 @@ +package statehistory + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" +) + +type FlushBatchFn func(t *task) + +type ingestor struct { + batchSemaphore semaphore.ResourceSemaphore[db.Batch] + database db.KeyValueReader + tasks []task + transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error +} + +var _ pipeline.State[felt.Address, task] = (*ingestor)(nil) + +func newIngestor( + sem semaphore.ResourceSemaphore[db.Batch], + database db.KeyValueReader, + transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error, +) *ingestor { + tasks := make([]task, ingestorCount) + for i := range tasks { + tasks[i] = task{batch: sem.GetBlocking()} + } + return &ingestor{batchSemaphore: sem, database: database, tasks: tasks, transform: transform} +} + +func (p *ingestor) Run(index int, addr felt.Address, outputs chan<- task) error { + t := &p.tasks[index] + flush := func(t *task) { + if t.batch.Size() < targetBatchByteSize { + return + } + outputs <- *t + *t = task{batch: p.batchSemaphore.GetBlocking()} + } + + if err := p.transform(p.database, t, addr, flush); err != nil { + return err + } + + if t.batch.Size() >= targetBatchByteSize { + outputs <- *t + *t = task{batch: p.batchSemaphore.GetBlocking()} + } + return nil +} + +func (p *ingestor) Done(index int, outputs chan<- task) error { + outputs <- p.tasks[index] + return nil +} diff --git a/migration/statehistory/migrator.go b/migration/statehistory/migrator.go new file mode 100644 index 0000000000..cc3eed6386 --- /dev/null +++ b/migration/statehistory/migrator.go @@ -0,0 +1,133 @@ +package statehistory + +import ( + "context" + "fmt" + "time" + + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" +) + +const ( + batchByteSize = 128 * db.Megabyte + targetBatchByteSize = 96 * db.Megabyte + ingestorCount = 4 + timeLogRate = 5 * time.Second +) + +type task struct { + batch db.Batch + addrCount int + entryCount int +} + +var ( + shouldRerun = []byte{} + shouldNotRerun = []byte(nil) +) + +var _ migration.Migration = (*Migrator)(nil) + +type Migrator struct{} + +func (Migrator) Before([]byte) error { return nil } + +func (Migrator) Migrate( + ctx context.Context, + database db.KeyValueStore, + _ *networks.Network, + logger log.StructuredLogger, +) ([]byte, error) { + addresses, err := collectAddresses(database) + if err != nil { + return shouldRerun, err + } + if len(addresses) == 0 { + logger.Info("state history migration: Contract bucket empty, marking applied") + return shouldNotRerun, nil + } + + if err := runPhase( + ctx, + database, + logger, + addresses, + "class-hash", + TransformClassHashHistory, + ); err != nil { + return shouldRerun, err + } + if err := runPhase(ctx, database, logger, addresses, "nonce", TransformNonceHistory); err != nil { + return shouldRerun, err + } + if err := runPhase( + ctx, + database, + logger, + addresses, + "storage", + TransformStorageHistory, + ); err != nil { + return shouldRerun, err + } + + return shouldNotRerun, nil +} + +func runPhase( + ctx context.Context, + database db.KeyValueStore, + logger log.StructuredLogger, + addresses []felt.Address, + name string, + transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error, +) error { + sem := semaphore.New(ingestorCount+1, func() db.Batch { + return database.NewBatchWithSize(batchByteSize) + }) + src := pipeline.Source(func(yield func(felt.Address) bool) { + for _, a := range addresses { + if !yield(a) { + return + } + } + }) + ingestors := pipeline.New(src, ingestorCount, newIngestor(sem, database, transform)) + committers := pipeline.New(ingestors, 1, newCommitter(logger, sem, name)) + + _, wait := committers.Run(ctx) + res := wait() + if res.Err != nil { + return res.Err + } + if !res.IsDone { + return fmt.Errorf("%s phase did not complete", name) + } + return nil +} + +func collectAddresses(r db.KeyValueReader) ([]felt.Address, error) { + prefix := db.Contract.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + return nil, err + } + defer it.Close() + + var addrs []felt.Address + for valid := it.First(); valid; valid = it.Next() { + key := it.Key() + if len(key) != len(prefix)+felt.Bytes { + continue + } + f := felt.FromBytes[felt.Felt](key[len(prefix):]) + addrs = append(addrs, felt.Address(f)) + } + return addrs, nil +} diff --git a/migration/statehistory/migrator_test.go b/migration/statehistory/migrator_test.go new file mode 100644 index 0000000000..30d9978f94 --- /dev/null +++ b/migration/statehistory/migrator_test.go @@ -0,0 +1,570 @@ +package statehistory + +import ( + "context" + "testing" + + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/deprecatedstate" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/triedb" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func seedContract( + t *testing.T, + memDB db.KeyValueStore, + addr felt.Felt, + nonce, classHash felt.Felt, + deployHeight uint64, +) { + t.Helper() + require.NoError(t, state.WriteContract(memDB, &addr, nonce, classHash, deployHeight)) +} + +func seedDeprecatedClassHashHistory( + t *testing.T, + w db.KeyValueWriter, + addr felt.Felt, + block uint64, + oldValue felt.Felt, +) { + t.Helper() + require.NoError(t, core.WriteDeprecatedContractClassHashHistory(w, &addr, &oldValue, block)) +} + +func seedDeprecatedNonceHistory( + t *testing.T, + w db.KeyValueWriter, + addr felt.Felt, + block uint64, + oldValue felt.Felt, +) { + t.Helper() + require.NoError(t, core.WriteDeprecatedContractNonceHistory(w, &addr, &oldValue, block)) +} + +func seedDeprecatedStorageHistory( + t *testing.T, + w db.KeyValueWriter, + addr, slot felt.Felt, + block uint64, + oldValue felt.Felt, +) { + t.Helper() + require.NoError(t, core.WriteDeprecatedContractStorageHistory(w, &addr, &slot, &oldValue, block)) +} + +// seedDeprecatedStorageTrie populates the deprecated ContractStorage trie for +// `addr` with the given (slot -> value) leaves, so the storage phase can read +// head values. +func seedDeprecatedStorageTrie( + t *testing.T, + memDB db.KeyValueStore, + addr felt.Felt, + leaves map[felt.Felt]felt.Felt, +) { + t.Helper() + //nolint:staticcheck // Necessary for old state + txn := memDB.NewIndexedBatch() + tr, err := trie.NewTriePedersen( + txn, + db.ContractStorage.Key(addr.Marshal()), + deprecatedstate.ContractStorageTrieHeight, + ) + require.NoError(t, err) + for k, v := range leaves { + _, err := tr.Put(&k, &v) + require.NoError(t, err) + } + require.NoError(t, tr.Commit()) + require.NoError(t, txn.Write()) +} + +func bucketKeyCount(t *testing.T, r db.KeyValueReader, bucket db.Bucket) int { + t.Helper() + it, err := r.NewIterator(bucket.Key(), true) + require.NoError(t, err) + defer it.Close() + count := 0 + for valid := it.First(); valid; valid = it.Next() { + count++ + } + return count +} + +// ----- Tests ----- + +func TestMigrate_EmptyDB(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) +} + +// Class-hash phase: a contract that was never reclassed has no deprecated +// entries. After migration, history has exactly one entry: +// (addr, deploy_height) -> ClassHash. +func TestMigrate_ClassHash_DeployOnly(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + classHash := felt.FromUint64[felt.Felt](170) + seedContract(t, memDB, addr, felt.Zero, classHash, 100) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + got, err := reader.ContractClassHashAt(&addr, 100) + require.NoError(t, err) + assert.Equal(t, classHash, got) + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractClassHashHistory), + "deprecated must be empty", + ) +} + +// Class-hash phase: a contract reclassed once. The deprecated bucket has one +// entry holding the deploy class hash (the value before the replace). After +// migration, history has: (addr, deploy_height) -> deploy class hash, and +// (addr, replace_block) -> replaced class hash. +func TestMigrate_ClassHash_Reclassed(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + deployClass := felt.FromUint64[felt.Felt](170) + replacedClass := felt.FromUint64[felt.Felt](187) + deployHeight := uint64(100) + replaceBlock := uint64(300) + + seedContract(t, memDB, addr, felt.Zero, replacedClass, deployHeight) + seedDeprecatedClassHashHistory(t, memDB, addr, replaceBlock, deployClass) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractClassHashAt(&addr, deployHeight) + require.NoError(t, err) + assert.Equal(t, deployClass, got, "deploy entry preserved") + + got, err = reader.ContractClassHashAt(&addr, replaceBlock) + require.NoError(t, err) + assert.Equal(t, replacedClass, got, "replace block has post-update value (head)") + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractClassHashHistory), + "deprecated must be empty", + ) +} + +// Nonce phase: contract whose nonce was updated multiple times. +// Old: (B_1, 0), (B_2, n_1). New: (B_1, n_1), (B_2, head). +func TestMigrate_Nonce_Updated(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + classHash := felt.FromUint64[felt.Felt](170) + headNonce := felt.FromUint64[felt.Felt](66) + deployHeight := uint64(100) + + seedContract(t, memDB, addr, headNonce, classHash, deployHeight) + seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) + seedDeprecatedNonceHistory(t, memDB, addr, 300, felt.FromUint64[felt.Felt](1)) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractNonceAt(&addr, 200) + require.NoError(t, err) + assert.Equal( + t, + felt.FromUint64[felt.Felt](1), + got, + "value installed at 200 = next entry's old value", + ) + + got, err = reader.ContractNonceAt(&addr, 300) + require.NoError(t, err) + assert.Equal(t, headNonce, got, "value installed at 300 = head") + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractNonceHistory), + "deprecated must be empty", + ) +} + +// Nonce phase: deploy-only contract (nonce never updated). No deprecated +// entries. Migration is a no-op for this address; history stays empty for it. +func TestMigrate_Nonce_DeployOnly(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](170), 100) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.ContractNonceHistory), + "no nonce entry expected when never updated", + ) +} + +// Storage phase: a slot with multi-write history. +// Old has (B_1, 0), (B_2, firstVal), (B_3, secondVal). +// Head from old trie = headVal. New has (B_1, firstVal), (B_2, secondVal), (B_3, headVal). +func TestMigrate_Storage_MultiWrite(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + slot := felt.FromUint64[felt.Felt](170) + firstVal := felt.FromUint64[felt.Felt](5) + secondVal := felt.FromUint64[felt.Felt](12) + headVal := felt.FromUint64[felt.Felt](7) + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + + seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 200, firstVal) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 300, secondVal) + + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: headVal}) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractStorageAt(&addr, &slot, 100) + require.NoError(t, err) + assert.Equal(t, firstVal, got, "value installed at 100") + + got, err = reader.ContractStorageAt(&addr, &slot, 200) + require.NoError(t, err) + assert.Equal(t, secondVal, got, "value installed at 200") + + got, err = reader.ContractStorageAt(&addr, &slot, 300) + require.NoError(t, err) + assert.Equal(t, headVal, got, "value installed at 300 = head from trie") + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory), + "deprecated must be empty", + ) +} + +// Storage phase: a slot with one write at deploy block. Old: (B_1, 0). Head from trie = v. +// New: (B_1, v). +func TestMigrate_Storage_SingleWrite(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + slot := felt.FromUint64[felt.Felt](170) + v := felt.FromUint64[felt.Felt](9) + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: v}) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractStorageAt(&addr, &slot, 100) + require.NoError(t, err) + assert.Equal(t, v, got) + + assert.Equal(t, 0, bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory)) +} + +// Idempotency: running the migration twice produces the same history result +// and doesn't corrupt anything. +func TestMigrate_Idempotent(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + deployClass := felt.FromUint64[felt.Felt](170) + replacedClass := felt.FromUint64[felt.Felt](187) + + seedContract(t, memDB, addr, felt.FromUint64[felt.Felt](66), replacedClass, 100) + seedDeprecatedClassHashHistory(t, memDB, addr, 300, deployClass) + seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) + + for i := 0; i < 3; i++ { + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + } + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractClassHashAt(&addr, 100) + require.NoError(t, err) + assert.Equal(t, deployClass, got) + + got, err = reader.ContractClassHashAt(&addr, 300) + require.NoError(t, err) + assert.Equal(t, replacedClass, got) + + got, err = reader.ContractNonceAt(&addr, 200) + require.NoError(t, err) + assert.Equal(t, felt.FromUint64[felt.Felt](66), got) +} + +// Models a crash between writing some history entries and the final +// DeleteRange of the deprecated bucket for class-hash: history has the deploy +// entry written, the deprecated bucket is still fully intact. The migration +// must reach the same final state as a clean run. +func TestMigrate_ClassHash_ResumeFromPartial(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + deployClass := felt.FromUint64[felt.Felt](170) + replacedClass := felt.FromUint64[felt.Felt](187) + deployHeight := uint64(100) + replaceBlock := uint64(300) + + seedContract(t, memDB, addr, felt.Zero, replacedClass, deployHeight) + seedDeprecatedClassHashHistory(t, memDB, addr, replaceBlock, deployClass) + + // Simulate a prior partial run: deploy entry already written to history, + // deprecated bucket still fully populated, DeleteRange not yet executed. + require.NoError(t, state.WriteClassHashHistory(memDB, &addr, deployHeight, &deployClass)) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractClassHashAt(&addr, deployHeight) + require.NoError(t, err) + assert.Equal(t, deployClass, got, "deploy entry must survive partial-resume re-run") + + got, err = reader.ContractClassHashAt(&addr, replaceBlock) + require.NoError(t, err) + assert.Equal(t, replacedClass, got, "shifted entry at replace block") + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractClassHashHistory), + "deprecated must be empty after migration completes", + ) +} + +// Pedersen storage trie omits zero-valued leaves. A slot that has deprecated +// history (was once written) but whose current value is zero has no leaf in +// the trie. The lockstep iteration must surface head=Zero for such slots. +func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + zeroedSlot := felt.FromUint64[felt.Felt](170) + keptSlot := felt.FromUint64[felt.Felt](187) + keptHead := felt.FromUint64[felt.Felt](9) + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](204), 100) + + // Both slots have deprecated entries... + seedDeprecatedStorageHistory(t, memDB, addr, zeroedSlot, 100, felt.Zero) + seedDeprecatedStorageHistory(t, memDB, addr, zeroedSlot, 200, felt.FromUint64[felt.Felt](5)) + seedDeprecatedStorageHistory(t, memDB, addr, keptSlot, 100, felt.Zero) + + // ...but only keptSlot has a leaf in the trie. zeroedSlot's current value + // is 0 → Pedersen trie didn't store a leaf for it. + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{keptSlot: keptHead}) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + // keptSlot's last entry should be the head from the trie. + got, err := reader.ContractStorageAt(&addr, &keptSlot, 100) + require.NoError(t, err) + assert.Equal(t, keptHead, got, "kept slot last entry = head") + + // zeroedSlot's last entry should be 0 (no leaf in trie). + got, err = reader.ContractStorageAt(&addr, &zeroedSlot, 200) + require.NoError(t, err) + assert.Equal(t, felt.Zero, got, "zeroed slot last entry = Zero (no leaf in trie)") + + // First entry of zeroedSlot's deprecated history was at block 100 with + // v_1=0; new layout at block 100 should be v_2 = 0x5 (next entry's value). + got, err = reader.ContractStorageAt(&addr, &zeroedSlot, 100) + require.NoError(t, err) + assert.Equal(t, felt.FromUint64[felt.Felt](5), got, "shift-up at block 100 = next-entry's value") + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory), + "deprecated fully drained", + ) +} + +// Storage at scale: many slots × many writes per slot for a single address. +// Verifies the per-slot grouping plus per-address DeleteRange handles +// non-trivial volumes correctly. Not large enough to hit the 96MB batch +// threshold (memory test runtimes would be silly), but exercises the +// grouping + cleanup logic across hundreds of slot boundaries. +func TestMigrate_Storage_ManyEntries(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + headValues := map[felt.Felt]felt.Felt{} + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + + const ( + numSlots = 50 + numEntriesPerSlot = 20 + startBlock = uint64(100) + ) + + for s := uint64(1); s <= numSlots; s++ { + slot := felt.NewFromUint64[felt.Felt](s) + headVal := felt.NewFromUint64[felt.Felt](1000000 + s) + headValues[*slot] = *headVal + + for b := uint64(0); b < numEntriesPerSlot; b++ { + block := startBlock + b + oldVal := felt.Zero + if b > 0 { + oldVal = *felt.NewFromUint64[felt.Felt](s*10000 + b) + } + seedDeprecatedStorageHistory(t, memDB, addr, *slot, block, oldVal) + } + } + seedDeprecatedStorageTrie(t, memDB, addr, headValues) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + assert.Equal( + t, + 0, + bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory), + "deprecated storage history must be fully drained", + ) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + // Sample a few slots: the last block per slot must equal head; intermediate + // blocks must equal the next deprecated entry's old value. + for slot, head := range headValues { + lastBlock := startBlock + numEntriesPerSlot - 1 + got, err := reader.ContractStorageAt(&addr, &slot, lastBlock) + require.NoErrorf(t, err, "read storage failed for slot %v", &slot) + assert.Equalf(t, head, got, "last entry must equal head for slot %v", &slot) + } +} diff --git a/migration/statehistory/parse.go b/migration/statehistory/parse.go new file mode 100644 index 0000000000..196c9b6150 --- /dev/null +++ b/migration/statehistory/parse.go @@ -0,0 +1,28 @@ +package statehistory + +import ( + "encoding/binary" + "fmt" + + "github.com/NethermindEth/juno/core/felt" +) + +func parseBlockKey(key, prefix []byte) (uint64, error) { + if len(key) != len(prefix)+8 { + return 0, fmt.Errorf("malformed block-keyed entry: key len %d, want %d", len(key), len(prefix)+8) + } + return binary.BigEndian.Uint64(key[len(prefix):]), nil +} + +func parseStorageKey(key, prefix []byte) (felt.Felt, uint64, error) { + if len(key) != len(prefix)+felt.Bytes+8 { + return felt.Felt{}, 0, fmt.Errorf( + "malformed storage-history entry: key len %d, want %d", + len(key), + len(prefix)+felt.Bytes+8, + ) + } + slot := felt.FromBytes[felt.Felt](key[len(prefix) : len(prefix)+felt.Bytes]) + block := binary.BigEndian.Uint64(key[len(prefix)+felt.Bytes:]) + return slot, block, nil +} diff --git a/migration/statehistory/transform.go b/migration/statehistory/transform.go new file mode 100644 index 0000000000..b8352d31a4 --- /dev/null +++ b/migration/statehistory/transform.go @@ -0,0 +1,259 @@ +package statehistory + +import ( + "bytes" + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" +) + +func TransformClassHashHistory( + database db.KeyValueReader, + t *task, + addr felt.Address, + flush FlushBatchFn, +) error { + addrFelt := felt.Felt(addr) + deprecatedPrefix := db.DeprecatedContractClassHashHistoryKey(&addrFelt) + + contract, err := state.GetContract(database, &addrFelt) + if err != nil { + return fmt.Errorf("class-hash: GetContract(%s): %w", &addrFelt, err) + } + + deployKey := db.ContractClassHashHistoryAtBlockKey(&addrFelt, contract.DeployedHeight) + deployEntryExists, err := database.Has(deployKey) + if err != nil { + return fmt.Errorf("class-hash: Has(deploy entry): %w", err) + } + + depIt, err := database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("class-hash: open deprecated iter(%s): %w", &addrFelt, err) + } + defer depIt.Close() + + if !depIt.First() { + if deployEntryExists { + return nil + } + if err := state.WriteClassHashHistory( + t.batch, + &addrFelt, + contract.DeployedHeight, + &contract.ClassHash, + ); err != nil { + return err + } + t.addrCount++ + t.entryCount++ + flush(t) + return nil + } + + rawValue, err := depIt.Value() + if err != nil { + return fmt.Errorf("class-hash: read first value(%s): %w", &addrFelt, err) + } + deployClassHash := felt.FromBytes[felt.Felt](rawValue) + if err := state.WriteClassHashHistory( + t.batch, + &addrFelt, + contract.DeployedHeight, + &deployClassHash, + ); err != nil { + return err + } + t.entryCount++ + flush(t) + + if _, err := shiftUpHistoryEntries( + depIt, + deprecatedPrefix, + &addrFelt, + contract.ClassHash, + t, + flush, + state.WriteClassHashHistory, + ); err != nil { + return fmt.Errorf("class-hash(%s): %w", &addrFelt, err) + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("class-hash: DeleteRange deprecated(%s): %w", &addrFelt, err) + } + t.addrCount++ + + return nil +} + +func TransformNonceHistory( + database db.KeyValueReader, + t *task, + addr felt.Address, + flush FlushBatchFn, +) error { + addrFelt := felt.Felt(addr) + deprecatedPrefix := db.DeprecatedContractNonceHistoryKey(&addrFelt) + + depIt, err := database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("nonce: open deprecated iter(%s): %w", &addrFelt, err) + } + defer depIt.Close() + if !depIt.First() { + return nil + } + + contract, err := state.GetContract(database, &addrFelt) + if err != nil { + return fmt.Errorf("nonce: GetContract(%s): %w", &addrFelt, err) + } + + if _, err := shiftUpHistoryEntries( + depIt, + deprecatedPrefix, + &addrFelt, + contract.Nonce, + t, + flush, + state.WriteNonceHistory, + ); err != nil { + return fmt.Errorf("nonce(%s): %w", &addrFelt, err) + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("nonce: DeleteRange deprecated(%s): %w", &addrFelt, err) + } + t.addrCount++ + + return nil +} + +func TransformStorageHistory( + database db.KeyValueReader, + t *task, + addr felt.Address, + flush FlushBatchFn) error { + addrFelt := felt.Felt(addr) + addrBytes := addrFelt.Marshal() + deprecatedPrefix := db.DeprecatedContractStorageHistory.Key(addrBytes) + + deprecatedHistoryIt, err := database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("storage: open deprecated iter(%s): %w", &addrFelt, err) + } + defer deprecatedHistoryIt.Close() + if !deprecatedHistoryIt.First() { + return nil + } + + leafPrefix := db.ContractStorage.Key(addrBytes) + leafPrefix = append(leafPrefix, 251) + + headStorageTrieIt, err := database.NewIterator(leafPrefix, true) + if err != nil { + return fmt.Errorf("storage: open leaf iter(%s): %w", &addrFelt, err) + } + defer headStorageTrieIt.Close() + leafValid := headStorageTrieIt.First() + + for { + slot, block, err := parseStorageKey(deprecatedHistoryIt.Key(), deprecatedPrefix) + if err != nil { + return fmt.Errorf("storage: parse key(%s): %w", &addrFelt, err) + } + + hasNext := deprecatedHistoryIt.Next() + var successorSlot, successorValue felt.Felt + if hasNext { + s, _, err := parseStorageKey(deprecatedHistoryIt.Key(), deprecatedPrefix) + if err != nil { + return fmt.Errorf("storage: parse successor key(%s): %w", &addrFelt, err) + } + successorSlot = s + rawValue, err := deprecatedHistoryIt.Value() + if err != nil { + return fmt.Errorf("storage: read successor value(%s, slot=%s): %w", &addrFelt, &s, err) + } + successorValue = felt.FromBytes[felt.Felt](rawValue) + } + + var historyValue felt.Felt + switch { + case hasNext && successorSlot == slot: + historyValue = successorValue + case leafValid: + leafSlot := headStorageTrieIt.Key()[len(leafPrefix):] + if bytes.Equal(leafSlot, slot.Marshal()) { + raw, err := headStorageTrieIt.Value() + if err != nil { + return fmt.Errorf("storage: leaf(%s, slot=%s): %w", &addrFelt, &slot, err) + } + var node trie.Node + if err := node.UnmarshalBinary(raw); err != nil { + return fmt.Errorf("storage: decode leaf(%s, slot=%s): %w", &addrFelt, &slot, err) + } + historyValue = *node.Value + leafValid = headStorageTrieIt.Next() + } + } + + if err := state.WriteStorageHistory(t.batch, &addrFelt, &slot, block, &historyValue); err != nil { + return err + } + t.entryCount++ + flush(t) + + if !hasNext { + break + } + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("storage: DeleteRange deprecated(%s): %w", &addrFelt, err) + } + t.addrCount++ + + return nil +} + +func shiftUpHistoryEntries( + depIt db.Iterator, + deprecatedPrefix []byte, + addr *felt.Felt, + headValue felt.Felt, + t *task, + flush FlushBatchFn, + write func(db.KeyValueWriter, *felt.Felt, uint64, *felt.Felt) error, +) (int, error) { + entryCount := 0 + for { + block, err := parseBlockKey(depIt.Key(), deprecatedPrefix) + if err != nil { + return 0, err + } + hasNext := depIt.Next() + historyValue := headValue + if hasNext { + rawValue, err := depIt.Value() + if err != nil { + return 0, err + } + historyValue = felt.FromBytes[felt.Felt](rawValue) + } + if err := write(t.batch, addr, block, &historyValue); err != nil { + return 0, err + } + entryCount++ + t.entryCount++ + flush(t) + if !hasNext { + return entryCount, nil + } + } +} diff --git a/node/migration.go b/node/migration.go index 71f387597f..57a53e9eac 100644 --- a/node/migration.go +++ b/node/migration.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/migration/deprecated" //nolint:staticcheck,nolintlint,lll // ignore statick check package will be removed in future, nolinlint because main config does not check "github.com/NethermindEth/juno/migration/headstate" "github.com/NethermindEth/juno/migration/historyprunner" + "github.com/NethermindEth/juno/migration/statehistory" "github.com/NethermindEth/juno/utils/log" ) @@ -27,7 +28,8 @@ func registerMigrations(cfg *Config) *migration.Registry { cfg.Prune, PruneModeFlag, ). - WithOptional(&headstate.Migrator{}, cfg.NewState, "new-state") + WithOptional(&headstate.Migrator{}, cfg.NewState, "new-state"). + WithOptional(&statehistory.Migrator{}, cfg.NewState, "new-state") return registry } From 405bb23a3ee51fa369aa800ce53481e9b6211de2 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 19 May 2026 23:16:24 +0200 Subject: [PATCH 04/14] chore: self review --- migration/statehistory/class_hash_ingestor.go | 120 +++++++ migration/statehistory/committer.go | 7 +- migration/statehistory/counter.go | 26 +- migration/statehistory/ingestor.go | 66 ++-- migration/statehistory/migrator.go | 141 ++++---- migration/statehistory/migrator_test.go | 313 ++++++++++++++---- migration/statehistory/nonce_ingestor.go | 79 +++++ migration/statehistory/storage_ingestor.go | 175 ++++++++++ migration/statehistory/transform.go | 259 --------------- 9 files changed, 761 insertions(+), 425 deletions(-) create mode 100644 migration/statehistory/class_hash_ingestor.go create mode 100644 migration/statehistory/nonce_ingestor.go create mode 100644 migration/statehistory/storage_ingestor.go delete mode 100644 migration/statehistory/transform.go diff --git a/migration/statehistory/class_hash_ingestor.go b/migration/statehistory/class_hash_ingestor.go new file mode 100644 index 0000000000..0aea360b29 --- /dev/null +++ b/migration/statehistory/class_hash_ingestor.go @@ -0,0 +1,120 @@ +package statehistory + +import ( + "context" + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" +) + +type classHashIngestor struct { + baseIngestor +} + +var _ pipeline.State[*felt.Felt, task] = (*classHashIngestor)(nil) + +func newClassHashIngestor( + ctx context.Context, + sem semaphore.ResourceSemaphore[db.Batch], + database db.KeyValueReader, +) *classHashIngestor { + return &classHashIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} +} + +func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { + t := &i.tasks[index] + + deprecatedPrefix := db.DeprecatedContractClassHashHistoryKey(addr) + contract, err := state.GetContract(i.database, addr) + if err != nil { + return fmt.Errorf("class-hash: GetContract(%s): %w", addr, err) + } + + deployKey := db.ContractClassHashHistoryAtBlockKey(addr, contract.DeployedHeight) + deployEntryExists, err := i.database.Has(deployKey) + if err != nil { + return fmt.Errorf("class-hash: Has(deploy entry): %w", err) + } + + depIt, err := i.database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("class-hash: open deprecated iter(%s): %w", addr, err) + } + defer depIt.Close() + + if !depIt.First() { + if deployEntryExists { + return nil + } + err = state.WriteClassHashHistory( + t.batch, + addr, + contract.DeployedHeight, + &contract.ClassHash, + ) + if err != nil { + return err + } + t.completedAddrs++ + t.entryCount++ + return i.flush(t, outputs) + } + + rawValue, err := depIt.Value() + if err != nil { + return fmt.Errorf("class-hash: read first value(%s): %w", addr, err) + } + deployClassHash := felt.FromBytes[felt.Felt](rawValue) + if err := state.WriteClassHashHistory( + t.batch, + addr, + contract.DeployedHeight, + &deployClassHash, + ); err != nil { + return err + } + t.entryCount++ + if err := i.flush(t, outputs); err != nil { + return err + } + + // Shift-up loop: each block in the deprecated history gets the *next* + // entry's value (since in the old layout the value at block B was the + // value before B's write). The final block gets the head class hash. + for { + block, err := parseBlockKey(depIt.Key(), deprecatedPrefix) + if err != nil { + return fmt.Errorf("class-hash(%s): %w", addr, err) + } + hasNext := depIt.Next() + historyValue := contract.ClassHash + if hasNext { + rawValue, err := depIt.Value() + if err != nil { + return fmt.Errorf("class-hash(%s): %w", addr, err) + } + historyValue = felt.FromBytes[felt.Felt](rawValue) + } + if err := state.WriteClassHashHistory(t.batch, addr, block, &historyValue); err != nil { + return err + } + t.entryCount++ + if err := i.flush(t, outputs); err != nil { + return err + } + if !hasNext { + break + } + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("class-hash: DeleteRange deprecated(%s): %w", addr, err) + } + t.completedAddrs++ + return nil +} diff --git a/migration/statehistory/committer.go b/migration/statehistory/committer.go index 4f409b9352..8dbff6b03e 100644 --- a/migration/statehistory/committer.go +++ b/migration/statehistory/committer.go @@ -29,9 +29,11 @@ func newCommitter( } func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { + defer c.batchSemaphore.Put() + c.logger.Debug( "writing batch", - zap.Int("addrCount", t.addrCount), + zap.Int("completedAddrs", t.completedAddrs), zap.Int("entryCount", t.entryCount), zap.Int("batchSize", t.batch.Size()), ) @@ -41,8 +43,7 @@ func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { return err } - c.counter.log(byteSize, t.addrCount, t.entryCount) - c.batchSemaphore.Put() + c.counter.log(byteSize, t.completedAddrs, t.entryCount) return nil } diff --git a/migration/statehistory/counter.go b/migration/statehistory/counter.go index d5b10f26d8..82e2aff9d0 100644 --- a/migration/statehistory/counter.go +++ b/migration/statehistory/counter.go @@ -9,13 +9,13 @@ import ( ) type counter struct { - logger log.StructuredLogger - timeLogRate time.Duration - phaseName string - start time.Time - size uint64 - addrCount uint64 - entryCount uint64 + logger log.StructuredLogger + timeLogRate time.Duration + phaseName string + start time.Time + size uint64 + completedAddrs uint64 + entryCount uint64 } func newCounter(logger log.StructuredLogger, timeLogRate time.Duration, phaseName string) counter { @@ -27,29 +27,29 @@ func newCounter(logger log.StructuredLogger, timeLogRate time.Duration, phaseNam } } -func (c *counter) log(byteSize uint64, addrCount, entryCount int) { +func (c *counter) log(byteSize uint64, completedAddrs, entryCount int) { c.size += byteSize - c.addrCount += uint64(addrCount) + c.completedAddrs += uint64(completedAddrs) c.entryCount += uint64(entryCount) now := time.Now() elapsed := now.Sub(c.start).Seconds() - if elapsed > float64(c.timeLogRate.Seconds()) { + if elapsed > c.timeLogRate.Seconds() { mbs := float64(c.size) / float64(db.Megabyte) c.logger.Info( "write speed", zap.String("phase", c.phaseName), zap.Float64("MB", mbs), zap.Float64("MB/s", mbs/elapsed), - zap.Uint64("contracts", c.addrCount), - zap.Float64("contracts/s", float64(c.addrCount)/elapsed), + zap.Uint64("completedContracts", c.completedAddrs), + zap.Float64("completedContracts/s", float64(c.completedAddrs)/elapsed), zap.Uint64("entries", c.entryCount), zap.Float64("entries/s", float64(c.entryCount)/elapsed), zap.Float64("time", elapsed), ) c.start = now c.size = 0 - c.addrCount = 0 + c.completedAddrs = 0 c.entryCount = 0 } } diff --git a/migration/statehistory/ingestor.go b/migration/statehistory/ingestor.go index f58e825a8d..ffd2001583 100644 --- a/migration/statehistory/ingestor.go +++ b/migration/statehistory/ingestor.go @@ -1,57 +1,63 @@ package statehistory import ( - "github.com/NethermindEth/juno/core/felt" + "context" + "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/migration/pipeline" "github.com/NethermindEth/juno/migration/semaphore" ) -type FlushBatchFn func(t *task) - -type ingestor struct { +type baseIngestor struct { + ctx context.Context batchSemaphore semaphore.ResourceSemaphore[db.Batch] database db.KeyValueReader tasks []task - transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error } -var _ pipeline.State[felt.Address, task] = (*ingestor)(nil) - -func newIngestor( +// newBaseIngestor pre-allocates one batch per ingestor slot. The semaphore is +// created with capacity ingestorCount+1 immediately before this call, so the +// acquires cannot block — using GetBlocking keeps the constructor signature +// error-free. +func newBaseIngestor( + ctx context.Context, sem semaphore.ResourceSemaphore[db.Batch], database db.KeyValueReader, - transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error, -) *ingestor { +) baseIngestor { tasks := make([]task, ingestorCount) for i := range tasks { tasks[i] = task{batch: sem.GetBlocking()} } - return &ingestor{batchSemaphore: sem, database: database, tasks: tasks, transform: transform} -} - -func (p *ingestor) Run(index int, addr felt.Address, outputs chan<- task) error { - t := &p.tasks[index] - flush := func(t *task) { - if t.batch.Size() < targetBatchByteSize { - return - } - outputs <- *t - *t = task{batch: p.batchSemaphore.GetBlocking()} + return baseIngestor{ + ctx: ctx, + batchSemaphore: sem, + database: database, + tasks: tasks, } +} - if err := p.transform(p.database, t, addr, flush); err != nil { - return err +// flush emits the current task downstream when its batch hits target size and +// acquires a fresh batch. The ctx-aware select on the channel send is the +// snappy cancellation point. The semaphore acquire uses GetBlocking — it is +// guaranteed to unblock within one committer iteration because the committer's +// deferred Put always runs. +func (b *baseIngestor) flush(t *task, outputs chan<- task) error { + if t.batch.Size() < targetBatchByteSize { + return nil } - - if t.batch.Size() >= targetBatchByteSize { - outputs <- *t - *t = task{batch: p.batchSemaphore.GetBlocking()} + select { + case <-b.ctx.Done(): + return b.ctx.Err() + case outputs <- *t: } + *t = task{batch: b.batchSemaphore.GetBlocking()} return nil } -func (p *ingestor) Done(index int, outputs chan<- task) error { - outputs <- p.tasks[index] +func (b *baseIngestor) Done(index int, outputs chan<- task) error { + select { + case <-b.ctx.Done(): + return b.ctx.Err() + case outputs <- b.tasks[index]: + } return nil } diff --git a/migration/statehistory/migrator.go b/migration/statehistory/migrator.go index cc3eed6386..6819e97b3f 100644 --- a/migration/statehistory/migrator.go +++ b/migration/statehistory/migrator.go @@ -2,7 +2,9 @@ package statehistory import ( "context" + "errors" "fmt" + "iter" "time" "github.com/NethermindEth/juno/blockchain/networks" @@ -22,9 +24,9 @@ const ( ) type task struct { - batch db.Batch - addrCount int - entryCount int + batch db.Batch + completedAddrs int + entryCount int } var ( @@ -44,90 +46,111 @@ func (Migrator) Migrate( _ *networks.Network, logger log.StructuredLogger, ) ([]byte, error) { - addresses, err := collectAddresses(database) - if err != nil { + if err := runClassHashPhase(ctx, database, logger); err != nil { return shouldRerun, err } - if len(addresses) == 0 { - logger.Info("state history migration: Contract bucket empty, marking applied") - return shouldNotRerun, nil - } - - if err := runPhase( - ctx, - database, - logger, - addresses, - "class-hash", - TransformClassHashHistory, - ); err != nil { + if err := runNoncePhase(ctx, database, logger); err != nil { return shouldRerun, err } - if err := runPhase(ctx, database, logger, addresses, "nonce", TransformNonceHistory); err != nil { - return shouldRerun, err - } - if err := runPhase( - ctx, - database, - logger, - addresses, - "storage", - TransformStorageHistory, - ); err != nil { + if err := runStoragePhase(ctx, database, logger); err != nil { return shouldRerun, err } return shouldNotRerun, nil } -func runPhase( +func runClassHashPhase( + ctx context.Context, + database db.KeyValueStore, + logger log.StructuredLogger, +) error { + sem, src, sourceErr := setupBeforePhase(database) + ing := newClassHashIngestor(ctx, sem, database) + return runPipeline(ctx, "class-hash", src, ing, logger, sem, sourceErr) +} + +func runNoncePhase( + ctx context.Context, + database db.KeyValueStore, + logger log.StructuredLogger, +) error { + sem, src, sourceErr := setupBeforePhase(database) + ing := newNonceIngestor(ctx, sem, database) + return runPipeline(ctx, "nonce", src, ing, logger, sem, sourceErr) +} + +func runStoragePhase( ctx context.Context, database db.KeyValueStore, logger log.StructuredLogger, - addresses []felt.Address, - name string, - transform func(db.KeyValueReader, *task, felt.Address, FlushBatchFn) error, ) error { + sem, src, sourceErr := setupBeforePhase(database) + ing := newStorageIngestor(ctx, sem, database) + return runPipeline(ctx, "storage", src, ing, logger, sem, sourceErr) +} + +func setupBeforePhase( + database db.KeyValueStore, +) (semaphore.ResourceSemaphore[db.Batch], pipeline.Pipeline[*felt.Felt], func() error) { sem := semaphore.New(ingestorCount+1, func() db.Batch { return database.NewBatchWithSize(batchByteSize) }) - src := pipeline.Source(func(yield func(felt.Address) bool) { - for _, a := range addresses { - if !yield(a) { - return - } - } - }) - ingestors := pipeline.New(src, ingestorCount, newIngestor(sem, database, transform)) + seq, sourceErr := addressSeq(database) + return sem, pipeline.Source(seq), sourceErr +} + +func runPipeline( + ctx context.Context, + name string, + src pipeline.Pipeline[*felt.Felt], + ing pipeline.State[*felt.Felt, task], + logger log.StructuredLogger, + sem semaphore.ResourceSemaphore[db.Batch], + sourceErr func() error, +) error { + ingestors := pipeline.New(src, ingestorCount, ing) committers := pipeline.New(ingestors, 1, newCommitter(logger, sem, name)) _, wait := committers.Run(ctx) res := wait() - if res.Err != nil { - return res.Err + + if err := errors.Join(sourceErr(), res.Err); err != nil { + return fmt.Errorf("%s: %w", name, err) } if !res.IsDone { + if ctxErr := ctx.Err(); ctxErr != nil { + return fmt.Errorf("%s: %w", name, ctxErr) + } return fmt.Errorf("%s phase did not complete", name) } return nil } -func collectAddresses(r db.KeyValueReader) ([]felt.Address, error) { - prefix := db.Contract.Key() - it, err := r.NewIterator(prefix, true) - if err != nil { - return nil, err - } - defer it.Close() - - var addrs []felt.Address - for valid := it.First(); valid; valid = it.Next() { - key := it.Key() - if len(key) != len(prefix)+felt.Bytes { - continue +func addressSeq(r db.KeyValueReader) (iter.Seq[*felt.Felt], func() error) { + var iterErr error + seq := func(yield func(*felt.Felt) bool) { + prefix := db.Contract.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + iterErr = err + return + } + defer it.Close() + for valid := it.First(); valid; valid = it.Next() { + key := it.Key() + if len(key) != len(prefix)+felt.Bytes { + iterErr = fmt.Errorf( + "malformed Contract key: len %d, want %d", + len(key), + len(prefix)+felt.Bytes, + ) + return + } + f := felt.FromBytes[felt.Felt](key[len(prefix):]) + if !yield(&f) { + return + } } - f := felt.FromBytes[felt.Felt](key[len(prefix):]) - addrs = append(addrs, felt.Address(f)) } - return addrs, nil + return seq, func() error { return iterErr } } diff --git a/migration/statehistory/migrator_test.go b/migration/statehistory/migrator_test.go index 30d9978f94..00c21dbb50 100644 --- a/migration/statehistory/migrator_test.go +++ b/migration/statehistory/migrator_test.go @@ -23,10 +23,9 @@ func seedContract( memDB db.KeyValueStore, addr felt.Felt, nonce, classHash felt.Felt, - deployHeight uint64, ) { t.Helper() - require.NoError(t, state.WriteContract(memDB, &addr, nonce, classHash, deployHeight)) + require.NoError(t, state.WriteContract(memDB, &addr, nonce, classHash, 100)) } func seedDeprecatedClassHashHistory( @@ -62,9 +61,6 @@ func seedDeprecatedStorageHistory( require.NoError(t, core.WriteDeprecatedContractStorageHistory(w, &addr, &slot, &oldValue, block)) } -// seedDeprecatedStorageTrie populates the deprecated ContractStorage trie for -// `addr` with the given (slot -> value) leaves, so the storage phase can read -// head values. func seedDeprecatedStorageTrie( t *testing.T, memDB db.KeyValueStore, @@ -100,8 +96,6 @@ func bucketKeyCount(t *testing.T, r db.KeyValueReader, bucket db.Bucket) int { return count } -// ----- Tests ----- - func TestMigrate_EmptyDB(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -116,16 +110,13 @@ func TestMigrate_EmptyDB(t *testing.T) { require.Nil(t, res) } -// Class-hash phase: a contract that was never reclassed has no deprecated -// entries. After migration, history has exactly one entry: -// (addr, deploy_height) -> ClassHash. func TestMigrate_ClassHash_DeployOnly(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) addr := felt.FromUint64[felt.Felt](1) classHash := felt.FromUint64[felt.Felt](170) - seedContract(t, memDB, addr, felt.Zero, classHash, 100) + seedContract(t, memDB, addr, felt.Zero, classHash) res, err := Migrator{}.Migrate( context.Background(), @@ -150,10 +141,6 @@ func TestMigrate_ClassHash_DeployOnly(t *testing.T) { ) } -// Class-hash phase: a contract reclassed once. The deprecated bucket has one -// entry holding the deploy class hash (the value before the replace). After -// migration, history has: (addr, deploy_height) -> deploy class hash, and -// (addr, replace_block) -> replaced class hash. func TestMigrate_ClassHash_Reclassed(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -164,7 +151,7 @@ func TestMigrate_ClassHash_Reclassed(t *testing.T) { deployHeight := uint64(100) replaceBlock := uint64(300) - seedContract(t, memDB, addr, felt.Zero, replacedClass, deployHeight) + seedContract(t, memDB, addr, felt.Zero, replacedClass) seedDeprecatedClassHashHistory(t, memDB, addr, replaceBlock, deployClass) res, err := Migrator{}.Migrate( @@ -195,8 +182,6 @@ func TestMigrate_ClassHash_Reclassed(t *testing.T) { ) } -// Nonce phase: contract whose nonce was updated multiple times. -// Old: (B_1, 0), (B_2, n_1). New: (B_1, n_1), (B_2, head). func TestMigrate_Nonce_Updated(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -204,9 +189,8 @@ func TestMigrate_Nonce_Updated(t *testing.T) { addr := felt.FromUint64[felt.Felt](1) classHash := felt.FromUint64[felt.Felt](170) headNonce := felt.FromUint64[felt.Felt](66) - deployHeight := uint64(100) - seedContract(t, memDB, addr, headNonce, classHash, deployHeight) + seedContract(t, memDB, addr, headNonce, classHash) seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) seedDeprecatedNonceHistory(t, memDB, addr, 300, felt.FromUint64[felt.Felt](1)) @@ -243,14 +227,12 @@ func TestMigrate_Nonce_Updated(t *testing.T) { ) } -// Nonce phase: deploy-only contract (nonce never updated). No deprecated -// entries. Migration is a no-op for this address; history stays empty for it. func TestMigrate_Nonce_DeployOnly(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) addr := felt.FromUint64[felt.Felt](1) - seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](170), 100) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](170)) res, err := Migrator{}.Migrate( context.Background(), @@ -269,9 +251,6 @@ func TestMigrate_Nonce_DeployOnly(t *testing.T) { ) } -// Storage phase: a slot with multi-write history. -// Old has (B_1, 0), (B_2, firstVal), (B_3, secondVal). -// Head from old trie = headVal. New has (B_1, firstVal), (B_2, secondVal), (B_3, headVal). func TestMigrate_Storage_MultiWrite(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -282,7 +261,7 @@ func TestMigrate_Storage_MultiWrite(t *testing.T) { secondVal := felt.FromUint64[felt.Felt](12) headVal := felt.FromUint64[felt.Felt](7) - seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187)) seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) seedDeprecatedStorageHistory(t, memDB, addr, slot, 200, firstVal) @@ -322,8 +301,6 @@ func TestMigrate_Storage_MultiWrite(t *testing.T) { ) } -// Storage phase: a slot with one write at deploy block. Old: (B_1, 0). Head from trie = v. -// New: (B_1, v). func TestMigrate_Storage_SingleWrite(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -332,7 +309,7 @@ func TestMigrate_Storage_SingleWrite(t *testing.T) { slot := felt.FromUint64[felt.Felt](170) v := felt.FromUint64[felt.Felt](9) - seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187)) seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: v}) @@ -355,8 +332,6 @@ func TestMigrate_Storage_SingleWrite(t *testing.T) { assert.Equal(t, 0, bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory)) } -// Idempotency: running the migration twice produces the same history result -// and doesn't corrupt anything. func TestMigrate_Idempotent(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -365,11 +340,11 @@ func TestMigrate_Idempotent(t *testing.T) { deployClass := felt.FromUint64[felt.Felt](170) replacedClass := felt.FromUint64[felt.Felt](187) - seedContract(t, memDB, addr, felt.FromUint64[felt.Felt](66), replacedClass, 100) + seedContract(t, memDB, addr, felt.FromUint64[felt.Felt](66), replacedClass) seedDeprecatedClassHashHistory(t, memDB, addr, 300, deployClass) seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) - for i := 0; i < 3; i++ { + for range 3 { res, err := Migrator{}.Migrate( context.Background(), memDB, @@ -396,10 +371,6 @@ func TestMigrate_Idempotent(t *testing.T) { assert.Equal(t, felt.FromUint64[felt.Felt](66), got) } -// Models a crash between writing some history entries and the final -// DeleteRange of the deprecated bucket for class-hash: history has the deploy -// entry written, the deprecated bucket is still fully intact. The migration -// must reach the same final state as a clean run. func TestMigrate_ClassHash_ResumeFromPartial(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -410,11 +381,9 @@ func TestMigrate_ClassHash_ResumeFromPartial(t *testing.T) { deployHeight := uint64(100) replaceBlock := uint64(300) - seedContract(t, memDB, addr, felt.Zero, replacedClass, deployHeight) + seedContract(t, memDB, addr, felt.Zero, replacedClass) seedDeprecatedClassHashHistory(t, memDB, addr, replaceBlock, deployClass) - // Simulate a prior partial run: deploy entry already written to history, - // deprecated bucket still fully populated, DeleteRange not yet executed. require.NoError(t, state.WriteClassHashHistory(memDB, &addr, deployHeight, &deployClass)) res, err := Migrator{}.Migrate( @@ -445,9 +414,6 @@ func TestMigrate_ClassHash_ResumeFromPartial(t *testing.T) { ) } -// Pedersen storage trie omits zero-valued leaves. A slot that has deprecated -// history (was once written) but whose current value is zero has no leaf in -// the trie. The lockstep iteration must surface head=Zero for such slots. func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -457,15 +423,12 @@ func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { keptSlot := felt.FromUint64[felt.Felt](187) keptHead := felt.FromUint64[felt.Felt](9) - seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](204), 100) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](204)) - // Both slots have deprecated entries... seedDeprecatedStorageHistory(t, memDB, addr, zeroedSlot, 100, felt.Zero) seedDeprecatedStorageHistory(t, memDB, addr, zeroedSlot, 200, felt.FromUint64[felt.Felt](5)) seedDeprecatedStorageHistory(t, memDB, addr, keptSlot, 100, felt.Zero) - // ...but only keptSlot has a leaf in the trie. zeroedSlot's current value - // is 0 → Pedersen trie didn't store a leaf for it. seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{keptSlot: keptHead}) res, err := Migrator{}.Migrate( @@ -480,18 +443,14 @@ func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) require.NoError(t, err) - // keptSlot's last entry should be the head from the trie. got, err := reader.ContractStorageAt(&addr, &keptSlot, 100) require.NoError(t, err) assert.Equal(t, keptHead, got, "kept slot last entry = head") - // zeroedSlot's last entry should be 0 (no leaf in trie). got, err = reader.ContractStorageAt(&addr, &zeroedSlot, 200) require.NoError(t, err) assert.Equal(t, felt.Zero, got, "zeroed slot last entry = Zero (no leaf in trie)") - // First entry of zeroedSlot's deprecated history was at block 100 with - // v_1=0; new layout at block 100 should be v_2 = 0x5 (next entry's value). got, err = reader.ContractStorageAt(&addr, &zeroedSlot, 100) require.NoError(t, err) assert.Equal(t, felt.FromUint64[felt.Felt](5), got, "shift-up at block 100 = next-entry's value") @@ -504,11 +463,6 @@ func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { ) } -// Storage at scale: many slots × many writes per slot for a single address. -// Verifies the per-slot grouping plus per-address DeleteRange handles -// non-trivial volumes correctly. Not large enough to hit the 96MB batch -// threshold (memory test runtimes would be silly), but exercises the -// grouping + cleanup logic across hundreds of slot boundaries. func TestMigrate_Storage_ManyEntries(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) @@ -516,7 +470,7 @@ func TestMigrate_Storage_ManyEntries(t *testing.T) { addr := felt.FromUint64[felt.Felt](1) headValues := map[felt.Felt]felt.Felt{} - seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187), 100) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187)) const ( numSlots = 50 @@ -529,7 +483,7 @@ func TestMigrate_Storage_ManyEntries(t *testing.T) { headVal := felt.NewFromUint64[felt.Felt](1000000 + s) headValues[*slot] = *headVal - for b := uint64(0); b < numEntriesPerSlot; b++ { + for b := range uint64(numEntriesPerSlot) { block := startBlock + b oldVal := felt.Zero if b > 0 { @@ -559,8 +513,6 @@ func TestMigrate_Storage_ManyEntries(t *testing.T) { reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) require.NoError(t, err) - // Sample a few slots: the last block per slot must equal head; intermediate - // blocks must equal the next deprecated entry's old value. for slot, head := range headValues { lastBlock := startBlock + numEntriesPerSlot - 1 got, err := reader.ContractStorageAt(&addr, &slot, lastBlock) @@ -568,3 +520,242 @@ func TestMigrate_Storage_ManyEntries(t *testing.T) { assert.Equalf(t, head, got, "last entry must equal head for slot %v", &slot) } } + +func TestMigrate_Storage_MultiAddress(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addrs := []felt.Felt{ + felt.FromUint64[felt.Felt](1), + felt.FromUint64[felt.Felt](2), + felt.FromUint64[felt.Felt](3), + } + slots := []felt.Felt{ + felt.FromUint64[felt.Felt](100), + felt.FromUint64[felt.Felt](200), + } + + for i := range addrs { + seedContract(t, memDB, addrs[i], felt.Zero, felt.FromUint64[felt.Felt](uint64(170+i))) + for _, slot := range slots { + seedDeprecatedStorageHistory(t, memDB, addrs[i], slot, 100, felt.Zero) + seedDeprecatedStorageHistory( + t, memDB, addrs[i], slot, 200, + felt.FromUint64[felt.Felt](uint64(10+i)), + ) + } + seedDeprecatedStorageTrie(t, memDB, addrs[i], map[felt.Felt]felt.Felt{ + slots[0]: felt.FromUint64[felt.Felt](uint64(1000 + i*10)), + slots[1]: felt.FromUint64[felt.Felt](uint64(1000 + i*10 + 1)), + }) + } + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + for i := range addrs { + for j, slot := range slots { + got, err := reader.ContractStorageAt(&addrs[i], &slot, 100) + require.NoErrorf(t, err, "addr %d slot %d block 100", i, j) + assert.Equalf( + t, felt.FromUint64[felt.Felt](uint64(10+i)), got, + "addr %d slot %d block 100 = next-entry's value", i, j, + ) + got, err = reader.ContractStorageAt(&addrs[i], &slot, 200) + require.NoErrorf(t, err, "addr %d slot %d block 200", i, j) + assert.Equalf( + t, felt.FromUint64[felt.Felt](uint64(1000+i*10+j)), got, + "addr %d slot %d block 200 = head from trie", i, j, + ) + } + } + + assert.Zero( + t, + bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory), + "deprecated must be empty across all addresses", + ) +} + +func TestMigrate_CancelledContext_ResumesCleanly(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](170)) + seedDeprecatedClassHashHistory(t, memDB, addr, 200, felt.FromUint64[felt.Felt](42)) + seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) + slot := felt.FromUint64[felt.Felt](5) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 200, felt.Zero) + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{ + slot: felt.FromUint64[felt.Felt](9), + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + res, err := Migrator{}.Migrate(ctx, memDB, &networks.Sepolia, log.NewNopZapLogger()) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.NotNil(t, res, "shouldRerun sentinel must not be nil") + require.Empty(t, res, "shouldRerun is a non-nil empty slice") + + res, err = Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractClassHashHistory)) + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractNonceHistory)) + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory)) +} + +func TestMigrate_Storage_ResumeFromPartial(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + slot := felt.FromUint64[felt.Felt](170) + firstVal := felt.FromUint64[felt.Felt](5) + secondVal := felt.FromUint64[felt.Felt](12) + headVal := felt.FromUint64[felt.Felt](7) + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](187)) + + seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 200, firstVal) + seedDeprecatedStorageHistory(t, memDB, addr, slot, 300, secondVal) + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: headVal}) + + require.NoError(t, state.WriteStorageHistory(memDB, &addr, &slot, 100, &firstVal)) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractStorageAt(&addr, &slot, 100) + require.NoError(t, err) + assert.Equal(t, firstVal, got, "preserved deploy entry after partial-resume") + + got, err = reader.ContractStorageAt(&addr, &slot, 200) + require.NoError(t, err) + assert.Equal(t, secondVal, got) + + got, err = reader.ContractStorageAt(&addr, &slot, 300) + require.NoError(t, err) + assert.Equal(t, headVal, got) + + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory)) +} + +func TestMigrate_AddressWithEmptyHistoryForOnePhase(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + classHash := felt.FromUint64[felt.Felt](170) + deployClass := felt.FromUint64[felt.Felt](42) + + seedContract(t, memDB, addr, felt.Zero, classHash) + seedDeprecatedClassHashHistory(t, memDB, addr, 300, deployClass) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + got, err := reader.ContractClassHashAt(&addr, 100) + require.NoError(t, err) + assert.Equal(t, deployClass, got) + + got, err = reader.ContractClassHashAt(&addr, 300) + require.NoError(t, err) + assert.Equal(t, classHash, got) + + assert.Zero( + t, bucketKeyCount(t, memDB, db.ContractNonceHistory), + "nonce history empty when phase no-ops", + ) + assert.Zero( + t, bucketKeyCount(t, memDB, db.ContractStorageHistory), + "storage history empty when phase no-ops", + ) + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractClassHashHistory)) +} + +func TestMigrate_Storage_InterleavedZeroedSlots(t *testing.T) { + memDB := memory.New() + t.Cleanup(func() { memDB.Close() }) + + addr := felt.FromUint64[felt.Felt](1) + slot1 := felt.FromUint64[felt.Felt](100) // zeroed + slot2 := felt.FromUint64[felt.Felt](200) // kept + slot3 := felt.FromUint64[felt.Felt](300) // zeroed + slot4 := felt.FromUint64[felt.Felt](400) // kept + head2 := felt.FromUint64[felt.Felt](22) + head4 := felt.FromUint64[felt.Felt](44) + + seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](204)) + + for _, slot := range []felt.Felt{slot1, slot2, slot3, slot4} { + seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) + } + seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{ + slot2: head2, + slot4: head4, + }) + + res, err := Migrator{}.Migrate( + context.Background(), + memDB, + &networks.Sepolia, + log.NewNopZapLogger(), + ) + require.NoError(t, err) + require.Nil(t, res) + + reader, err := state.NewStateReader(&felt.Zero, state.NewStateDB(memDB, triedb.New(memDB, nil))) + require.NoError(t, err) + + for _, slot := range []felt.Felt{slot1, slot3} { + got, err := reader.ContractStorageAt(&addr, &slot, 100) + require.NoErrorf(t, err, "zeroed slot %v", &slot) + assert.Equalf(t, felt.Zero, got, "zeroed slot %v has no trie leaf", &slot) + } + got, err := reader.ContractStorageAt(&addr, &slot2, 100) + require.NoError(t, err) + assert.Equal(t, head2, got, "slot2 kept its head value") + got, err = reader.ContractStorageAt(&addr, &slot4, 100) + require.NoError(t, err) + assert.Equal(t, head4, got, "slot4 kept its head value") + + assert.Zero(t, bucketKeyCount(t, memDB, db.DeprecatedContractStorageHistory)) +} diff --git a/migration/statehistory/nonce_ingestor.go b/migration/statehistory/nonce_ingestor.go new file mode 100644 index 0000000000..32f9ef7d71 --- /dev/null +++ b/migration/statehistory/nonce_ingestor.go @@ -0,0 +1,79 @@ +package statehistory + +import ( + "context" + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" +) + +type nonceIngestor struct { + baseIngestor +} + +var _ pipeline.State[*felt.Felt, task] = (*nonceIngestor)(nil) + +func newNonceIngestor( + ctx context.Context, + sem semaphore.ResourceSemaphore[db.Batch], + database db.KeyValueReader, +) *nonceIngestor { + return &nonceIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} +} + +func (i *nonceIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { + t := &i.tasks[index] + deprecatedPrefix := db.DeprecatedContractNonceHistoryKey(addr) + + depIt, err := i.database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("nonce: open deprecated iter(%s): %w", addr, err) + } + defer depIt.Close() + if !depIt.First() { + return nil + } + + contract, err := state.GetContract(i.database, addr) + if err != nil { + return fmt.Errorf("nonce: GetContract(%s): %w", addr, err) + } + + for { + block, err := parseBlockKey(depIt.Key(), deprecatedPrefix) + if err != nil { + return fmt.Errorf("nonce(%s): %w", addr, err) + } + hasNext := depIt.Next() + historyValue := contract.Nonce + if hasNext { + rawValue, err := depIt.Value() + if err != nil { + return fmt.Errorf("nonce(%s): %w", addr, err) + } + historyValue = felt.FromBytes[felt.Felt](rawValue) + } + err = state.WriteNonceHistory(t.batch, addr, block, &historyValue) + if err != nil { + return err + } + t.entryCount++ + if err := i.flush(t, outputs); err != nil { + return err + } + if !hasNext { + break + } + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("nonce: DeleteRange deprecated(%s): %w", addr, err) + } + t.completedAddrs++ + return nil +} diff --git a/migration/statehistory/storage_ingestor.go b/migration/statehistory/storage_ingestor.go new file mode 100644 index 0000000000..a2270e319c --- /dev/null +++ b/migration/statehistory/storage_ingestor.go @@ -0,0 +1,175 @@ +package statehistory + +import ( + "bytes" + "context" + "fmt" + + "github.com/NethermindEth/juno/core/deprecatedstate" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" +) + +type storageIngestor struct { + baseIngestor +} + +var _ pipeline.State[*felt.Felt, task] = (*storageIngestor)(nil) + +func newStorageIngestor( + ctx context.Context, + sem semaphore.ResourceSemaphore[db.Batch], + database db.KeyValueReader, +) *storageIngestor { + return &storageIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} +} + +func (i *storageIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { + t := &i.tasks[index] + + addrBytes := addr.Marshal() + deprecatedPrefix := db.DeprecatedContractStorageHistory.Key(addrBytes) + + deprecatedHistoryIt, err := i.database.NewIterator(deprecatedPrefix, true) + if err != nil { + return fmt.Errorf("storage: open deprecated iter(%s): %w", addr, err) + } + defer deprecatedHistoryIt.Close() + if !deprecatedHistoryIt.First() { + return nil + } + + leafPrefix := db.ContractStorage.Key(addrBytes) + leafPrefix = append(leafPrefix, deprecatedstate.ContractStorageTrieHeight) + + headStorageTrieIt, err := i.database.NewIterator(leafPrefix, true) + if err != nil { + return fmt.Errorf("storage: open leaf iter(%s): %w", addr, err) + } + defer headStorageTrieIt.Close() + leafValid := headStorageTrieIt.First() + + for { + slot, block, err := parseStorageKey(deprecatedHistoryIt.Key(), deprecatedPrefix) + if err != nil { + return fmt.Errorf("storage: parse key(%s): %w", addr, err) + } + + successorSlot, successorValue, hasNext, err := peekSuccessor( + deprecatedHistoryIt, + deprecatedPrefix, + addr, + ) + if err != nil { + return err + } + + historyValue, advanced, err := resolveHistoryValue( + headStorageTrieIt, + leafPrefix, + addr, + &slot, + hasNext, + successorSlot, + successorValue, + leafValid, + ) + if err != nil { + return err + } + if advanced { + leafValid = headStorageTrieIt.Next() + } + + err = state.WriteStorageHistory( + t.batch, + addr, + &slot, + block, + &historyValue, + ) + if err != nil { + return err + } + t.entryCount++ + if err := i.flush(t, outputs); err != nil { + return err + } + + if !hasNext { + break + } + } + + if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + return fmt.Errorf("storage: DeleteRange deprecated(%s): %w", addr, err) + } + t.completedAddrs++ + return nil +} + +// peekSuccessor advances the deprecated-history iterator. If a next entry +// exists, returns its (slot, value, true); otherwise returns (_, _, false). +func peekSuccessor( + it db.Iterator, + prefix []byte, + addr *felt.Felt, +) (felt.Felt, felt.Felt, bool, error) { + if !it.Next() { + return felt.Felt{}, felt.Felt{}, false, nil + } + slot, _, err := parseStorageKey(it.Key(), prefix) + if err != nil { + return felt.Felt{}, felt.Felt{}, false, fmt.Errorf( + "storage: parse successor key(%s): %w", addr, err, + ) + } + rawValue, err := it.Value() + if err != nil { + return felt.Felt{}, felt.Felt{}, false, fmt.Errorf( + "storage: read successor value(%s, slot=%s): %w", addr, &slot, err, + ) + } + return slot, felt.FromBytes[felt.Felt](rawValue), true, nil +} + +// resolveHistoryValue decides the value to install at the current entry: the +// successor's value when both are on the same slot, otherwise the head-trie +// leaf (when the iterator is positioned on this slot), otherwise zero. Returns +// advanced=true when the head-trie iterator should be advanced by the caller. +func resolveHistoryValue( + headIt db.Iterator, + leafPrefix []byte, + addr, slot *felt.Felt, + hasSuccessor bool, + successorSlot, successorValue felt.Felt, + leafValid bool, +) (value felt.Felt, advanced bool, err error) { + if hasSuccessor && successorSlot == *slot { + return successorValue, false, nil + } + if !leafValid { + return felt.Felt{}, false, nil + } + if !bytes.Equal(headIt.Key()[len(leafPrefix):], slot.Marshal()) { + return felt.Felt{}, false, nil + } + raw, err := headIt.Value() + if err != nil { + return felt.Felt{}, false, fmt.Errorf( + "storage: leaf(%s, slot=%s): %w", addr, slot, err, + ) + } + var node trie.Node + if err := node.UnmarshalBinary(raw); err != nil { + return felt.Felt{}, false, fmt.Errorf( + "storage: decode leaf(%s, slot=%s): %w", addr, slot, err, + ) + } + return *node.Value, true, nil +} diff --git a/migration/statehistory/transform.go b/migration/statehistory/transform.go deleted file mode 100644 index b8352d31a4..0000000000 --- a/migration/statehistory/transform.go +++ /dev/null @@ -1,259 +0,0 @@ -package statehistory - -import ( - "bytes" - "fmt" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/dbutils" -) - -func TransformClassHashHistory( - database db.KeyValueReader, - t *task, - addr felt.Address, - flush FlushBatchFn, -) error { - addrFelt := felt.Felt(addr) - deprecatedPrefix := db.DeprecatedContractClassHashHistoryKey(&addrFelt) - - contract, err := state.GetContract(database, &addrFelt) - if err != nil { - return fmt.Errorf("class-hash: GetContract(%s): %w", &addrFelt, err) - } - - deployKey := db.ContractClassHashHistoryAtBlockKey(&addrFelt, contract.DeployedHeight) - deployEntryExists, err := database.Has(deployKey) - if err != nil { - return fmt.Errorf("class-hash: Has(deploy entry): %w", err) - } - - depIt, err := database.NewIterator(deprecatedPrefix, true) - if err != nil { - return fmt.Errorf("class-hash: open deprecated iter(%s): %w", &addrFelt, err) - } - defer depIt.Close() - - if !depIt.First() { - if deployEntryExists { - return nil - } - if err := state.WriteClassHashHistory( - t.batch, - &addrFelt, - contract.DeployedHeight, - &contract.ClassHash, - ); err != nil { - return err - } - t.addrCount++ - t.entryCount++ - flush(t) - return nil - } - - rawValue, err := depIt.Value() - if err != nil { - return fmt.Errorf("class-hash: read first value(%s): %w", &addrFelt, err) - } - deployClassHash := felt.FromBytes[felt.Felt](rawValue) - if err := state.WriteClassHashHistory( - t.batch, - &addrFelt, - contract.DeployedHeight, - &deployClassHash, - ); err != nil { - return err - } - t.entryCount++ - flush(t) - - if _, err := shiftUpHistoryEntries( - depIt, - deprecatedPrefix, - &addrFelt, - contract.ClassHash, - t, - flush, - state.WriteClassHashHistory, - ); err != nil { - return fmt.Errorf("class-hash(%s): %w", &addrFelt, err) - } - - if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { - return fmt.Errorf("class-hash: DeleteRange deprecated(%s): %w", &addrFelt, err) - } - t.addrCount++ - - return nil -} - -func TransformNonceHistory( - database db.KeyValueReader, - t *task, - addr felt.Address, - flush FlushBatchFn, -) error { - addrFelt := felt.Felt(addr) - deprecatedPrefix := db.DeprecatedContractNonceHistoryKey(&addrFelt) - - depIt, err := database.NewIterator(deprecatedPrefix, true) - if err != nil { - return fmt.Errorf("nonce: open deprecated iter(%s): %w", &addrFelt, err) - } - defer depIt.Close() - if !depIt.First() { - return nil - } - - contract, err := state.GetContract(database, &addrFelt) - if err != nil { - return fmt.Errorf("nonce: GetContract(%s): %w", &addrFelt, err) - } - - if _, err := shiftUpHistoryEntries( - depIt, - deprecatedPrefix, - &addrFelt, - contract.Nonce, - t, - flush, - state.WriteNonceHistory, - ); err != nil { - return fmt.Errorf("nonce(%s): %w", &addrFelt, err) - } - - if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { - return fmt.Errorf("nonce: DeleteRange deprecated(%s): %w", &addrFelt, err) - } - t.addrCount++ - - return nil -} - -func TransformStorageHistory( - database db.KeyValueReader, - t *task, - addr felt.Address, - flush FlushBatchFn) error { - addrFelt := felt.Felt(addr) - addrBytes := addrFelt.Marshal() - deprecatedPrefix := db.DeprecatedContractStorageHistory.Key(addrBytes) - - deprecatedHistoryIt, err := database.NewIterator(deprecatedPrefix, true) - if err != nil { - return fmt.Errorf("storage: open deprecated iter(%s): %w", &addrFelt, err) - } - defer deprecatedHistoryIt.Close() - if !deprecatedHistoryIt.First() { - return nil - } - - leafPrefix := db.ContractStorage.Key(addrBytes) - leafPrefix = append(leafPrefix, 251) - - headStorageTrieIt, err := database.NewIterator(leafPrefix, true) - if err != nil { - return fmt.Errorf("storage: open leaf iter(%s): %w", &addrFelt, err) - } - defer headStorageTrieIt.Close() - leafValid := headStorageTrieIt.First() - - for { - slot, block, err := parseStorageKey(deprecatedHistoryIt.Key(), deprecatedPrefix) - if err != nil { - return fmt.Errorf("storage: parse key(%s): %w", &addrFelt, err) - } - - hasNext := deprecatedHistoryIt.Next() - var successorSlot, successorValue felt.Felt - if hasNext { - s, _, err := parseStorageKey(deprecatedHistoryIt.Key(), deprecatedPrefix) - if err != nil { - return fmt.Errorf("storage: parse successor key(%s): %w", &addrFelt, err) - } - successorSlot = s - rawValue, err := deprecatedHistoryIt.Value() - if err != nil { - return fmt.Errorf("storage: read successor value(%s, slot=%s): %w", &addrFelt, &s, err) - } - successorValue = felt.FromBytes[felt.Felt](rawValue) - } - - var historyValue felt.Felt - switch { - case hasNext && successorSlot == slot: - historyValue = successorValue - case leafValid: - leafSlot := headStorageTrieIt.Key()[len(leafPrefix):] - if bytes.Equal(leafSlot, slot.Marshal()) { - raw, err := headStorageTrieIt.Value() - if err != nil { - return fmt.Errorf("storage: leaf(%s, slot=%s): %w", &addrFelt, &slot, err) - } - var node trie.Node - if err := node.UnmarshalBinary(raw); err != nil { - return fmt.Errorf("storage: decode leaf(%s, slot=%s): %w", &addrFelt, &slot, err) - } - historyValue = *node.Value - leafValid = headStorageTrieIt.Next() - } - } - - if err := state.WriteStorageHistory(t.batch, &addrFelt, &slot, block, &historyValue); err != nil { - return err - } - t.entryCount++ - flush(t) - - if !hasNext { - break - } - } - - if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { - return fmt.Errorf("storage: DeleteRange deprecated(%s): %w", &addrFelt, err) - } - t.addrCount++ - - return nil -} - -func shiftUpHistoryEntries( - depIt db.Iterator, - deprecatedPrefix []byte, - addr *felt.Felt, - headValue felt.Felt, - t *task, - flush FlushBatchFn, - write func(db.KeyValueWriter, *felt.Felt, uint64, *felt.Felt) error, -) (int, error) { - entryCount := 0 - for { - block, err := parseBlockKey(depIt.Key(), deprecatedPrefix) - if err != nil { - return 0, err - } - hasNext := depIt.Next() - historyValue := headValue - if hasNext { - rawValue, err := depIt.Value() - if err != nil { - return 0, err - } - historyValue = felt.FromBytes[felt.Felt](rawValue) - } - if err := write(t.batch, addr, block, &historyValue); err != nil { - return 0, err - } - entryCount++ - t.entryCount++ - flush(t) - if !hasNext { - return entryCount, nil - } - } -} From 46c5aefe01012827b10359c54ebd26b19297e791 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 00:35:54 +0200 Subject: [PATCH 05/14] chore: self review --- migration/statehistory/class_hash_ingestor.go | 108 ++++++++++++------ migration/statehistory/migrator.go | 27 +++++ migration/statehistory/nonce_ingestor.go | 20 ++++ migration/statehistory/storage_ingestor.go | 44 +++++++ 4 files changed, 166 insertions(+), 33 deletions(-) diff --git a/migration/statehistory/class_hash_ingestor.go b/migration/statehistory/class_hash_ingestor.go index 0aea360b29..2c810fbf8d 100644 --- a/migration/statehistory/class_hash_ingestor.go +++ b/migration/statehistory/class_hash_ingestor.go @@ -26,6 +26,31 @@ func newClassHashIngestor( return &classHashIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} } +// Run migrates the class-hash history of a single contract. +// +// Legend: Bₙ = block at which the n-th class-hash *replacement* happened. +// Vₙ = the class hash active *after* Bₙ; V₀ is the deploy-time hash. The +// deprecated layout writes nothing at deploy: each entry is written only +// on a *Replace*, and the value stored is the hash that was active before +// that replace. So deprecated[B₁] = V₀ even though no replace happened at +// deploy_h itself. The new layout adds an explicit deploy entry and shifts +// everything else by one slot: +// +// block │ deprecated │ new +// ─────────┼────────────────┼────── +// deploy_h │ — │ V₀ ← inserted from first deprecated entry +// B₁ │ V₀ │ V₁ +// B₂ │ V₁ │ V₂ +// B₃ │ V₂ │ V₃ +// ─────────┼────────────────┼────── +// > B₃ │ contract │ V₃ (last entry — self-contained) +// .ClassHash ← deprecated must reach into the Contract +// record for any block past the last replace +// +// If the deprecated history is empty (no replaces ever), the single deploy +// entry is written with contract.ClassHash directly. Deprecated rows are +// deleted at the end of the run. Resume-safe: empty-deprecated + existing +// deploy entry → no-op. func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { t := &i.tasks[index] @@ -35,12 +60,6 @@ func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) return fmt.Errorf("class-hash: GetContract(%s): %w", addr, err) } - deployKey := db.ContractClassHashHistoryAtBlockKey(addr, contract.DeployedHeight) - deployEntryExists, err := i.database.Has(deployKey) - if err != nil { - return fmt.Errorf("class-hash: Has(deploy entry): %w", err) - } - depIt, err := i.database.NewIterator(deprecatedPrefix, true) if err != nil { return fmt.Errorf("class-hash: open deprecated iter(%s): %w", addr, err) @@ -48,34 +67,60 @@ func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) defer depIt.Close() if !depIt.First() { - if deployEntryExists { - return nil - } - err = state.WriteClassHashHistory( - t.batch, - addr, - contract.DeployedHeight, - &contract.ClassHash, - ) - if err != nil { - return err - } - t.completedAddrs++ - t.entryCount++ - return i.flush(t, outputs) + return i.writeDeployOnly(t, outputs, addr, contract.DeployedHeight, &contract.ClassHash) + } + return i.writeShiftedHistory( + t, outputs, depIt, deprecatedPrefix, addr, + contract.DeployedHeight, &contract.ClassHash, + ) +} + +// writeDeployOnly handles the "no deprecated history" branch: write the +// deploy-time entry from contract.ClassHash, unless a previous run already +// wrote it. +func (i *classHashIngestor) writeDeployOnly( + t *task, + outputs chan<- task, + addr *felt.Felt, + deployHeight uint64, + classHash *felt.Felt, +) error { + deployKey := db.ContractClassHashHistoryAtBlockKey(addr, deployHeight) + deployEntryExists, err := i.database.Has(deployKey) + if err != nil { + return fmt.Errorf("class-hash: Has(deploy entry): %w", err) + } + if deployEntryExists { + return nil } + if err := state.WriteClassHashHistory(t.batch, addr, deployHeight, classHash); err != nil { + return err + } + t.completedAddrs++ + t.entryCount++ + return i.flush(t, outputs) +} +// writeShiftedHistory handles the "non-empty deprecated history" branch: +// writes the deploy entry from the first deprecated value, shifts each +// deprecated entry into the new layout using the next entry's pre-value +// (or contract.ClassHash for the last), and deletes the deprecated rows. +// depIt must be positioned at the first deprecated entry. +func (i *classHashIngestor) writeShiftedHistory( + t *task, + outputs chan<- task, + depIt db.Iterator, + prefix []byte, + addr *felt.Felt, + deployHeight uint64, + headClassHash *felt.Felt, +) error { rawValue, err := depIt.Value() if err != nil { return fmt.Errorf("class-hash: read first value(%s): %w", addr, err) } deployClassHash := felt.FromBytes[felt.Felt](rawValue) - if err := state.WriteClassHashHistory( - t.batch, - addr, - contract.DeployedHeight, - &deployClassHash, - ); err != nil { + if err := state.WriteClassHashHistory(t.batch, addr, deployHeight, &deployClassHash); err != nil { return err } t.entryCount++ @@ -83,16 +128,13 @@ func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) return err } - // Shift-up loop: each block in the deprecated history gets the *next* - // entry's value (since in the old layout the value at block B was the - // value before B's write). The final block gets the head class hash. for { - block, err := parseBlockKey(depIt.Key(), deprecatedPrefix) + block, err := parseBlockKey(depIt.Key(), prefix) if err != nil { return fmt.Errorf("class-hash(%s): %w", addr, err) } hasNext := depIt.Next() - historyValue := contract.ClassHash + historyValue := *headClassHash if hasNext { rawValue, err := depIt.Value() if err != nil { @@ -112,7 +154,7 @@ func (i *classHashIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) } } - if err := t.batch.DeleteRange(deprecatedPrefix, dbutils.UpperBound(deprecatedPrefix)); err != nil { + if err := t.batch.DeleteRange(prefix, dbutils.UpperBound(prefix)); err != nil { return fmt.Errorf("class-hash: DeleteRange deprecated(%s): %w", addr, err) } t.completedAddrs++ diff --git a/migration/statehistory/migrator.go b/migration/statehistory/migrator.go index 6819e97b3f..96d2b56eb6 100644 --- a/migration/statehistory/migrator.go +++ b/migration/statehistory/migrator.go @@ -36,6 +36,33 @@ var ( var _ migration.Migration = (*Migrator)(nil) +// Migrator rewrites the contract history layout so each entry stores the +// post-update value at its block, instead of the pre-update value. +// +// Example — a contract whose class hash was 0xAA at deploy (block 100), +// changed to 0xBB at block 200, then to 0xCC at block 500: +// +// block │ old layout (pre-value) │ new layout (post-value) +// ──────┼────────────────────────┼───────────────────────── +// 100 │ (no entry) │ 0xAA ← explicit deploy +// 200 │ 0xAA │ 0xBB +// 500 │ 0xBB │ 0xCC +// head │ 0xCC (contract record) │ (read from history) +// +// The same shape change applies to nonces and per-slot storage. The +// migrator runs three phases (class-hash, nonce, storage); each phase +// iterates the Contract bucket and rewrites one contract's deprecated +// entries at a time, deleting them in the same batch. +// +// Crash / cancellation safety: pebble batches commit atomically, so the +// writes inside any single committed batch are durable as a unit. A +// contract whose history is large may span more than one batch — but each +// new entry's value is a pure function of the deprecated source data, so +// re-running over an already-partially-rewritten contract overwrites with +// identical values and then deletes the (still-present) deprecated rows. +// Contracts whose deprecated entries are already gone short-circuit on an +// empty iterator. The three phases run sequentially: a later phase only +// starts after the earlier phase completes. type Migrator struct{} func (Migrator) Before([]byte) error { return nil } diff --git a/migration/statehistory/nonce_ingestor.go b/migration/statehistory/nonce_ingestor.go index 32f9ef7d71..9e20210a56 100644 --- a/migration/statehistory/nonce_ingestor.go +++ b/migration/statehistory/nonce_ingestor.go @@ -26,6 +26,26 @@ func newNonceIngestor( return &nonceIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} } +// Run migrates the nonce history of a single contract. +// +// Legend: Bₙ = block at which the n-th nonce change happened. Nₙ = the +// nonce active *after* Bₙ; the deploy nonce is always 0 and is *not* +// written to the deprecated history — its presence is implicit in the +// pre-value of the first change entry. The new layout stores the same +// number of entries, just shifted to post-values: +// +// block │ deprecated │ new +// ───────┼────────────────┼────── +// B₁ │ 0 │ N₁ +// B₂ │ N₁ │ N₂ +// B₃ │ N₂ │ N₃ +// ───────┼────────────────┼────── +// > B₃ │ contract │ N₃ (last entry — self-contained) +// .Nonce ← deprecated must reach into the Contract +// record for any block past the last change +// +// Contracts with no deprecated nonce history are skipped. Deprecated rows +// are deleted at the end of the run. func (i *nonceIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { t := &i.tasks[index] deprecatedPrefix := db.DeprecatedContractNonceHistoryKey(addr) diff --git a/migration/statehistory/storage_ingestor.go b/migration/statehistory/storage_ingestor.go index a2270e319c..e8a29cf91f 100644 --- a/migration/statehistory/storage_ingestor.go +++ b/migration/statehistory/storage_ingestor.go @@ -29,6 +29,50 @@ func newStorageIngestor( return &storageIngestor{baseIngestor: newBaseIngestor(ctx, sem, database)} } +// Run migrates the per-slot storage history of a single contract. +// +// Legend: Bₙ = block at which the n-th change to a slot happened. preXₙ +// is the value of slot X before Bₙ (= what the deprecated layout stores +// at [X, Bₙ]); headX is the slot's current value, read from the head +// storage trie. The deprecated layout writes nothing at deploy — the +// pre-deploy value (0) is implicit in the first change entry. The new +// layout stores the same number of entries per slot, just shifted to +// post-values. For one slot: +// +// block │ deprecated[slotA] │ new[slotA] +// ───────┼───────────────────┼─────────── +// B₁ │ 0 │ preA₁ +// B₂ │ preA₁ │ preA₂ +// B₃ │ preA₂ │ headA +// ───────┼───────────────────┼─────────── +// > B₃ │ head trie leaf │ headA (last entry — self-contained) +// for slotA ← deprecated must reach into the head +// storage trie for any block past the +// last change +// +// For each deprecated entry the post-value comes from one of: +// +// - the *next* deprecated entry, when it's on the same slot — its stored +// pre-value is exactly this block's post-value; +// - the head storage trie leaf for that slot, when there is no next +// deprecated entry on the same slot; +// - felt.Zero, when there is no head leaf for the slot (the slot was +// eventually zeroed out and dropped from the trie). +// +// Both the deprecated history and the head trie are sorted by raw slot +// bytes, so the ingestor walks them in lockstep — the head-trie iterator +// advances only when its current leaf matches the slot just resolved: +// +// deprecated history head trie new history +// ───────────────────── ───────────── ───────────────────────── +// [slotA, B₁..B₃] ──→ [slotA] = headA [slotA, B₁..B₃] last uses headA +// [slotB, B₁..B₂] ──→ (no leaf) [slotB, B₁..B₂] last uses 0 +// ← slotB was set (slotB was zeroed +// and later zeroed at B₂) +// [slotC, B₁] ──→ [slotC] = headC [slotC, B₁] = headC +// +// Contracts with no deprecated storage history are skipped; deprecated +// rows are deleted at the end of the run via DeleteRange. func (i *storageIngestor) Run(index int, addr *felt.Felt, outputs chan<- task) error { t := &i.tasks[index] From 05dd000e874dee9cdc34709376b3dc64014b6891 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 14:50:58 +0200 Subject: [PATCH 06/14] test: test only public API --- migration/statehistory/migrator_test.go | 37 +++++++++++++------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/migration/statehistory/migrator_test.go b/migration/statehistory/migrator_test.go index 00c21dbb50..afb348f007 100644 --- a/migration/statehistory/migrator_test.go +++ b/migration/statehistory/migrator_test.go @@ -1,4 +1,4 @@ -package statehistory +package statehistory_test import ( "context" @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/migration/statehistory" "github.com/NethermindEth/juno/utils/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -100,7 +101,7 @@ func TestMigrate_EmptyDB(t *testing.T) { memDB := memory.New() t.Cleanup(func() { memDB.Close() }) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -118,7 +119,7 @@ func TestMigrate_ClassHash_DeployOnly(t *testing.T) { classHash := felt.FromUint64[felt.Felt](170) seedContract(t, memDB, addr, felt.Zero, classHash) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -154,7 +155,7 @@ func TestMigrate_ClassHash_Reclassed(t *testing.T) { seedContract(t, memDB, addr, felt.Zero, replacedClass) seedDeprecatedClassHashHistory(t, memDB, addr, replaceBlock, deployClass) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -194,7 +195,7 @@ func TestMigrate_Nonce_Updated(t *testing.T) { seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) seedDeprecatedNonceHistory(t, memDB, addr, 300, felt.FromUint64[felt.Felt](1)) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -234,7 +235,7 @@ func TestMigrate_Nonce_DeployOnly(t *testing.T) { addr := felt.FromUint64[felt.Felt](1) seedContract(t, memDB, addr, felt.Zero, felt.FromUint64[felt.Felt](170)) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -269,7 +270,7 @@ func TestMigrate_Storage_MultiWrite(t *testing.T) { seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: headVal}) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -313,7 +314,7 @@ func TestMigrate_Storage_SingleWrite(t *testing.T) { seedDeprecatedStorageHistory(t, memDB, addr, slot, 100, felt.Zero) seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{slot: v}) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -345,7 +346,7 @@ func TestMigrate_Idempotent(t *testing.T) { seedDeprecatedNonceHistory(t, memDB, addr, 200, felt.Zero) for range 3 { - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -386,7 +387,7 @@ func TestMigrate_ClassHash_ResumeFromPartial(t *testing.T) { require.NoError(t, state.WriteClassHashHistory(memDB, &addr, deployHeight, &deployClass)) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -431,7 +432,7 @@ func TestMigrate_Storage_ZeroedSlotHasNoLeaf(t *testing.T) { seedDeprecatedStorageTrie(t, memDB, addr, map[felt.Felt]felt.Felt{keptSlot: keptHead}) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -494,7 +495,7 @@ func TestMigrate_Storage_ManyEntries(t *testing.T) { } seedDeprecatedStorageTrie(t, memDB, addr, headValues) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -550,7 +551,7 @@ func TestMigrate_Storage_MultiAddress(t *testing.T) { }) } - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -603,13 +604,13 @@ func TestMigrate_CancelledContext_ResumesCleanly(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - res, err := Migrator{}.Migrate(ctx, memDB, &networks.Sepolia, log.NewNopZapLogger()) + res, err := statehistory.Migrator{}.Migrate(ctx, memDB, &networks.Sepolia, log.NewNopZapLogger()) require.Error(t, err) require.ErrorIs(t, err, context.Canceled) require.NotNil(t, res, "shouldRerun sentinel must not be nil") require.Empty(t, res, "shouldRerun is a non-nil empty slice") - res, err = Migrator{}.Migrate( + res, err = statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -642,7 +643,7 @@ func TestMigrate_Storage_ResumeFromPartial(t *testing.T) { require.NoError(t, state.WriteStorageHistory(memDB, &addr, &slot, 100, &firstVal)) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -680,7 +681,7 @@ func TestMigrate_AddressWithEmptyHistoryForOnePhase(t *testing.T) { seedContract(t, memDB, addr, felt.Zero, classHash) seedDeprecatedClassHashHistory(t, memDB, addr, 300, deployClass) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, @@ -733,7 +734,7 @@ func TestMigrate_Storage_InterleavedZeroedSlots(t *testing.T) { slot4: head4, }) - res, err := Migrator{}.Migrate( + res, err := statehistory.Migrator{}.Migrate( context.Background(), memDB, &networks.Sepolia, From bb292fd4ce91fa29c3cf1b3bd9cd4db17a3146f9 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 19 May 2026 13:18:45 +0200 Subject: [PATCH 07/14] feat(migration): trie migration --- core/trie2/trieutils/accessors.go | 31 +++ migration/trie/codec.go | 75 ++++++++ migration/trie/codec_test.go | 160 ++++++++++++++++ migration/trie/committer.go | 38 ++++ migration/trie/counter.go | 47 +++++ migration/trie/dfsmigration.go | 231 +++++++++++++++++++++++ migration/trie/dfsmigration_test.go | 273 +++++++++++++++++++++++++++ migration/trie/hashpool.go | 59 ++++++ migration/trie/hashworker.go | 188 +++++++++++++++++++ migration/trie/hashworker_test.go | 154 +++++++++++++++ migration/trie/ingestor.go | 95 ++++++++++ migration/trie/trie.go | 247 ++++++++++++++++++++++++ migration/trie/trie_test.go | 281 ++++++++++++++++++++++++++++ node/migration.go | 4 +- 14 files changed, 1882 insertions(+), 1 deletion(-) create mode 100644 migration/trie/codec.go create mode 100644 migration/trie/codec_test.go create mode 100644 migration/trie/committer.go create mode 100644 migration/trie/counter.go create mode 100644 migration/trie/dfsmigration.go create mode 100644 migration/trie/dfsmigration_test.go create mode 100644 migration/trie/hashpool.go create mode 100644 migration/trie/hashworker.go create mode 100644 migration/trie/hashworker_test.go create mode 100644 migration/trie/ingestor.go create mode 100644 migration/trie/trie.go create mode 100644 migration/trie/trie_test.go diff --git a/core/trie2/trieutils/accessors.go b/core/trie2/trieutils/accessors.go index 08f96c7166..bbfccc91ce 100644 --- a/core/trie2/trieutils/accessors.go +++ b/core/trie2/trieutils/accessors.go @@ -39,6 +39,37 @@ func WriteNodeByPath( return w.Put(nodeKeyByPath(bucket, owner, path, isLeaf), blob) } +// MaxNodeKeySize is the maximum byte length of a new-format trie node key: +// 1 (prefix) + 32 (owner, optional) + 1 (nodeType) + MaxBitArraySize (path). +const MaxNodeKeySize = 1 + 32 + 1 + MaxBitArraySize + +// EncodeNodeKey writes the node key into dst and returns the number of bytes written. +// dst must have at least MaxNodeKeySize bytes of capacity. +func EncodeNodeKey(dst []byte, bucket db.Bucket, owner *felt.Address, path *Path, isLeaf bool) int { + n := 0 + dst[n] = byte(bucket) + n++ + + if !felt.IsZero(owner) { + ownerBytes := owner.Bytes() + copy(dst[n:], ownerBytes[:]) + n += 32 + } + + if isLeaf { + dst[n] = leaf.Byte() + } else { + dst[n] = nonLeaf.Byte() + } + n++ + + pathBytes := path.EncodedBytes() + copy(dst[n:], pathBytes) + n += len(pathBytes) + + return n +} + func DeleteNodeByPath( w db.KeyValueWriter, bucket db.Bucket, diff --git a/migration/trie/codec.go b/migration/trie/codec.go new file mode 100644 index 0000000000..77fb642efc --- /dev/null +++ b/migration/trie/codec.go @@ -0,0 +1,75 @@ +package trie + +import ( + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +const ( + binaryNodeTag byte = 0x01 + edgeNodeTag byte = 0x02 + + valueNodeBlobSize = felt.Bytes + binaryNodeBlobSize = 1 + 2*felt.Bytes + edgeNodeMinSize = 1 + felt.Bytes + 1 + edgeNodeMaxSize = 1 + felt.Bytes + trieutils.MaxBitArraySize +) + +func toNewPath(old *trie.BitArray) trieutils.Path { + b := old.Bytes() + var p trieutils.Path + p.SetBytes(old.Len(), b[:]) + return p +} + +func encodeValueNode(value *felt.Felt) [valueNodeBlobSize]byte { + return value.Bytes() +} + +func encodeBinaryNode(leftEdgeHash, rightEdgeHash *felt.Felt) [binaryNodeBlobSize]byte { + var blob [binaryNodeBlobSize]byte + blob[0] = binaryNodeTag + lb := leftEdgeHash.Bytes() + rb := rightEdgeHash.Bytes() + copy(blob[1:], lb[:]) + copy(blob[1+felt.Bytes:], rb[:]) + return blob +} + +func encodeEdgeNode(childHash *felt.Felt, pathSeg *trieutils.Path) []byte { + encoded := pathSeg.EncodedBytes() + var arr [edgeNodeMaxSize]byte + arr[0] = edgeNodeTag + h := childHash.Bytes() + copy(arr[1:], h[:]) + copy(arr[1+felt.Bytes:], encoded) + return arr[:1+felt.Bytes+len(encoded)] +} + +func encodeEdgeNodeInto(dst []byte, childHash *felt.Felt, pathSeg *trieutils.Path) int { + encoded := pathSeg.EncodedBytes() + dst[0] = edgeNodeTag + h := childHash.Bytes() + copy(dst[1:], h[:]) + copy(dst[1+felt.Bytes:], encoded) + return 1 + felt.Bytes + len(encoded) +} + +func computeEdgeHash(childHash *felt.Felt, path *trieutils.Path, hashFn crypto.HashFn) felt.Felt { + if path.Len() == 0 { + return *childHash + } + pathFelt := path.Felt() + h := hashFn(childHash, &pathFelt) + lenFelt := felt.FromUint64[felt.Felt](uint64(path.Len())) + h.Add(&h, &lenFelt) + return h +} + +func compressedSegment(childFullPath *trie.BitArray, parentLen uint8) trieutils.Path { + var seg trie.BitArray + seg.LSBs(childFullPath, parentLen+1) + return toNewPath(&seg) +} diff --git a/migration/trie/codec_test.go b/migration/trie/codec_test.go new file mode 100644 index 0000000000..e468af7aa1 --- /dev/null +++ b/migration/trie/codec_test.go @@ -0,0 +1,160 @@ +package trie + +import ( + "bytes" + "fmt" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeOldPath(length uint8, val uint64) trie.BitArray { + var ba trie.BitArray + ba.SetUint64(length, val) + return ba +} + +func makeNewPath(length uint8, val uint64) trieutils.Path { + old := makeOldPath(length, val) + return toNewPath(&old) +} + +func TestToNewPath_PreservesLengthAndBits(t *testing.T) { + for _, tc := range []struct { + name string + len uint8 + val uint64 + }{ + {"zero length", 0, 0}, + {"single bit 0", 1, 0}, + {"single bit 1", 1, 1}, + {"8 bits", 8, 0xAB}, + {"251 bits", 251, 0xDEADBEEF}, + } { + t.Run(tc.name, func(t *testing.T) { + old := makeOldPath(tc.len, tc.val) + np := toNewPath(&old) + assert.Equal(t, tc.len, np.Len()) + assert.Equal(t, old.Bytes(), np.Bytes()) + }) + } +} + +func TestEncodeValueNode(t *testing.T) { + var v felt.Felt + v.SetUint64(0xCAFEBABE) + blob := encodeValueNode(&v) + assert.Equal(t, valueNodeBlobSize, len(blob)) + assert.Equal(t, v.Bytes(), blob) +} + +func TestEncodeBinaryNode(t *testing.T) { + var l, r felt.Felt + l.SetUint64(1) + r.SetUint64(2) + blob := encodeBinaryNode(&l, &r) + assert.Equal(t, binaryNodeBlobSize, len(blob)) + assert.Equal(t, binaryNodeTag, blob[0]) + lb := l.Bytes() + rb := r.Bytes() + assert.Equal(t, lb[:], blob[1:33]) + assert.Equal(t, rb[:], blob[33:65]) +} + +func TestEncodeEdgeNode(t *testing.T) { + for _, pathLen := range []uint8{0, 1, 10, 250} { + t.Run(fmt.Sprintf("pathLen=%d", pathLen), func(t *testing.T) { + var childHash felt.Felt + childHash.SetUint64(42) + seg := makeNewPath(pathLen, 0b101) + blob := encodeEdgeNode(&childHash, &seg) + + require.Greater(t, len(blob), 0) + assert.Equal(t, edgeNodeTag, blob[0]) + + var got felt.Felt + got.SetBytes(blob[1:33]) + assert.Equal(t, childHash, got) + + activeBytes := (int(pathLen) + 7) / 8 + assert.Equal(t, 1+felt.Bytes+activeBytes+1, len(blob)) + }) + } +} + +func TestComputeEdgeHash_ZeroLengthPath_ReturnChildHashUnchanged(t *testing.T) { + var childHash felt.Felt + childHash.SetUint64(99) + var path trieutils.Path + result := computeEdgeHash(&childHash, &path, crypto.Pedersen) + assert.True(t, result.Equal(&childHash)) +} + +func TestComputeEdgeHash_MatchesEdgeNodeHash(t *testing.T) { + for _, hashFn := range []crypto.HashFn{crypto.Pedersen, crypto.Poseidon} { + for _, pathLen := range []uint8{1, 10, 250} { + var childHash felt.Felt + childHash.SetUint64(123456) + seg := makeNewPath(pathLen, 0b1011) + + got := computeEdgeHash(&childHash, &seg, hashFn) + + hashNode := trienode.HashNode(childHash) + edge := &trienode.EdgeNode{Child: &hashNode, Path: &seg} + want := edge.Hash(hashFn) + + assert.True(t, got.Equal(&want), "pathLen=%d", pathLen) + } + } +} + +func TestCompressedSegment_Length(t *testing.T) { + for _, tc := range []struct { + parentLen uint8 + segLen uint8 + }{ + {0, 0}, + {0, 5}, + {10, 20}, + {100, 50}, + } { + childLen := tc.parentLen + 1 + tc.segLen + child := makeOldPath(childLen, 0b111) + seg := compressedSegment(&child, tc.parentLen) + assert.Equal(t, tc.segLen, seg.Len(), "parentLen=%d segLen=%d", tc.parentLen, tc.segLen) + } +} + +func TestOldTriePrefix_GlobalTrie(t *testing.T) { + desc := TrieDesc{OldBucket: db.ClassesTrie, Owner: felt.Address{}} + prefix := oldTriePrefix(desc) + assert.Equal(t, []byte{byte(db.ClassesTrie)}, prefix) +} + +func TestOldTriePrefix_StorageTrie(t *testing.T) { + var ownerFelt felt.Felt + ownerFelt.SetUint64(42) + owner := felt.Address(ownerFelt) + desc := TrieDesc{OldBucket: db.ContractStorage, Owner: owner} + prefix := oldTriePrefix(desc) + assert.Equal(t, byte(db.ContractStorage), prefix[0]) + assert.Equal(t, 1+felt.Bytes, len(prefix)) +} + +func TestParseOldPath_RoundTrip(t *testing.T) { + original := makeOldPath(17, 0b10101) + var buf bytes.Buffer + _, err := original.Write(&buf) + require.NoError(t, err) + var parsed trie.BitArray + require.NoError(t, parsed.UnmarshalBinary(buf.Bytes())) + assert.Equal(t, original.Len(), parsed.Len()) + assert.Equal(t, original.Bytes(), parsed.Bytes()) +} diff --git a/migration/trie/committer.go b/migration/trie/committer.go new file mode 100644 index 0000000000..d3a29ecbd9 --- /dev/null +++ b/migration/trie/committer.go @@ -0,0 +1,38 @@ +package trie + +import ( + "fmt" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" +) + +type committer struct { + batchSem semaphore.ResourceSemaphore[db.Batch] + counter counter +} + +func newCommitter( + logger log.StructuredLogger, + batchSem semaphore.ResourceSemaphore[db.Batch], +) *committer { + return &committer{ + batchSem: batchSem, + counter: newCounter(logger, timeLogRate), + } +} + +func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { + byteSize := uint64(t.batch.Size()) + if err := t.batch.Write(); err != nil { + return fmt.Errorf("trie migration: batch write failed: %w", err) + } + c.counter.log(byteSize, t.tries) + c.batchSem.Put() + return nil +} + +func (c *committer) Done(int, chan<- struct{}) error { + return nil +} diff --git a/migration/trie/counter.go b/migration/trie/counter.go new file mode 100644 index 0000000000..2d64cb4b5a --- /dev/null +++ b/migration/trie/counter.go @@ -0,0 +1,47 @@ +package trie + +import ( + "time" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +type counter struct { + logger log.StructuredLogger + timeLogRate time.Duration + start time.Time + size uint64 + tries uint64 +} + +func newCounter(logger log.StructuredLogger, timeLogRate time.Duration) counter { + return counter{ + logger: logger, + timeLogRate: timeLogRate, + start: time.Now(), + } +} + +func (c *counter) log(byteSize uint64, tries int) { + c.size += byteSize + c.tries += uint64(tries) + + now := time.Now() + elapsed := now.Sub(c.start).Seconds() + if elapsed > float64(c.timeLogRate.Seconds()) { + mbs := float64(c.size) / float64(db.Megabyte) + c.logger.Info( + "write speed", + zap.Float64("MB", mbs), + zap.Float64("MB/s", mbs/elapsed), + zap.Uint64("tries", c.tries), + zap.Float64("tries/s", float64(c.tries)/elapsed), + zap.Float64("time", elapsed), + ) + c.start = now + c.size = 0 + c.tries = 0 + } +} diff --git a/migration/trie/dfsmigration.go b/migration/trie/dfsmigration.go new file mode 100644 index 0000000000..65e28421d7 --- /dev/null +++ b/migration/trie/dfsmigration.go @@ -0,0 +1,231 @@ +package trie + +import ( + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" +) + +type dfsMigrator struct { + parallelDispatch bool + pool *hashWorkerPool +} + +func newDFSMigrator(parallelDispatch bool, pool *hashWorkerPool) *dfsMigrator { + return &dfsMigrator{parallelDispatch: parallelDispatch, pool: pool} +} + +func (m *dfsMigrator) Migrate( + r db.KeyValueReader, + batch db.Batch, + desc TrieDesc, + flush FlushBatchFn, + stack []dfsFrame, +) (db.Batch, []dfsFrame, error) { + stack = stack[:0] + if desc.RootPath == nil { + return batch, stack, nil + } + prefix := oldTriePrefix(desc) + rootPath := desc.RootPath + sched := newHashScheduler(desc.HashFn, + m.parallelDispatch, + desc.NewBucket, &desc.Owner, m.pool) + + rootHash, batch, stack, err := traverse(r, prefix, rootPath, sched, batch, flush, stack) + if err != nil { + return batch, stack, err + } + if err = sched.sync(batch); err != nil { + return batch, stack, err + } + if rootPath.Len() > 0 { + if err = writeRootEdge(rootPath, rootHash, sched, batch); err != nil { + return batch, stack, err + } + } + return batch, stack, nil +} + +type traverseStackState uint8 + +const ( + readNode traverseStackState = iota + leftSubtreeDone + rightSubtreeDone +) + +type dfsFrame struct { + oldPath trie.BitArray + left trie.BitArray + right trie.BitArray + value felt.Felt + leftHash felt.Felt + isLeaf bool + state traverseStackState + binaryDepth uint8 +} + +func traverse( + r db.KeyValueReader, + prefix []byte, + start *trie.BitArray, + sched *hashScheduler, + batch db.Batch, + flush FlushBatchFn, + stack []dfsFrame, +) (felt.Felt, db.Batch, []dfsFrame, error) { + stack = stack[:1] + stack[0] = dfsFrame{oldPath: *start} + var lastHash felt.Felt + for len(stack) > 0 { + top := &stack[len(stack)-1] + switch top.state { + case readNode: + if err := parseNodeInto(r, prefix, top); err != nil { + return felt.Felt{}, batch, stack, err + } + newPath := toNewPath(&top.oldPath) + if top.isLeaf { + if err := processLeaf(newPath, &top.value, sched, batch); err != nil { + return felt.Felt{}, batch, stack, err + } + lastHash = top.value + stack = stack[:len(stack)-1] + } else { + left := top.left + top.state = leftSubtreeDone + stack = pushFrame(stack, left) + } + case leftSubtreeDone: + top.leftHash = lastHash + right := top.right + top.state = rightSubtreeDone + stack = pushFrame(stack, right) + case rightSubtreeDone: + newPath := toNewPath(&top.oldPath) + if err := processBinary( + newPath, + &top.left, + &top.right, + top.leftHash, + lastHash, + sched, + batch, + ); err != nil { + return felt.Felt{}, batch, stack, err + } + lastHash = top.value + stack = stack[:len(stack)-1] + } + batch = flush(batch) + } + return lastHash, batch, stack, nil +} + +func pushFrame(stack []dfsFrame, oldPath trie.BitArray) []dfsFrame { + n := len(stack) + stack = stack[:n+1] + stack[n] = dfsFrame{oldPath: oldPath} + return stack +} + +const maxOldKeySize = 1 + 32 + 1 + 32 + +func encodeOldPath(path *trie.BitArray, dst []byte) int { + pathLen := path.Len() + b := path.Bytes() + activeBytes := (uint(pathLen) + 7) / 8 + dst[0] = pathLen + copy(dst[1:], b[32-activeBytes:]) + return int(activeBytes) + 1 +} + +func parseNodeInto(r db.KeyValueReader, prefix []byte, frame *dfsFrame) error { + var arr [maxOldKeySize]byte + n := copy(arr[:], prefix) + n += encodeOldPath(&frame.oldPath, arr[n:]) + return r.Get(arr[:n], func(val []byte) error { + return parseNodeData(val, &frame.value, &frame.left, &frame.right, &frame.isLeaf) + }) +} + +func parseNodeData(data []byte, value *felt.Felt, left, right *trie.BitArray, isLeaf *bool) error { + if len(data) < felt.Bytes { + return fmt.Errorf("trie: node data too short (%d bytes)", len(data)) + } + *value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) + data = data[felt.Bytes:] + if len(data) == 0 { + *isLeaf = true + return nil + } + *isLeaf = false + if err := left.UnmarshalBinary(data); err != nil { + return fmt.Errorf("trie: unmarshalling left path: %w", err) + } + data = data[left.EncodedLen():] + if err := right.UnmarshalBinary(data); err != nil { + return fmt.Errorf("trie: unmarshalling right path: %w", err) + } + return nil +} + +func processLeaf( + path trieutils.Path, + value *felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + var buf [trieutils.MaxNodeKeySize + valueNodeBlobSize]byte + keyLen := trieutils.EncodeNodeKey(buf[:], sched.bucket, &sched.owner, &path, true) + blob := encodeValueNode(value) + copy(buf[keyLen:], blob[:]) + return batch.Put(buf[:keyLen], buf[keyLen:keyLen+valueNodeBlobSize]) +} + +func processBinary( + parentPath trieutils.Path, + left, right *trie.BitArray, + leftChildHash, rightChildHash felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + leftSeg := compressedSegment(left, parentPath.Len()) + rightSeg := compressedSegment(right, parentPath.Len()) + leftFull := toNewPath(left) + rightFull := toNewPath(right) + var leftEdgePath, rightEdgePath trieutils.Path + leftEdgePath.MSBs(&leftFull, parentPath.Len()+1) + rightEdgePath.MSBs(&rightFull, parentPath.Len()+1) + return sched.schedule(edgeHashJob{ + leftChildHash: leftChildHash, + leftSeg: leftSeg, + rightChildHash: rightChildHash, + rightSeg: rightSeg, + parentPath: parentPath, + }, batch) +} + +func writeRootEdge( + rootPath *trie.BitArray, + childHash felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + seg := toNewPath(rootPath) + blob := encodeEdgeNode(&childHash, &seg) + return trieutils.WriteNodeByPath(batch, sched.bucket, &sched.owner, &trieutils.Path{}, false, blob) +} + +func oldTriePrefix(desc TrieDesc) []byte { + if desc.OldBucket == db.ContractStorage { + ownerFelt := felt.Felt(desc.Owner) + ownerBytes := ownerFelt.Bytes() + return desc.OldBucket.Key(ownerBytes[:]) + } + return desc.OldBucket.Key() +} diff --git a/migration/trie/dfsmigration_test.go b/migration/trie/dfsmigration_test.go new file mode 100644 index 0000000000..2ff05a252e --- /dev/null +++ b/migration/trie/dfsmigration_test.go @@ -0,0 +1,273 @@ +package trie + +import ( + "context" + "math/rand" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/triedb/rawdb" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type leafMap map[felt.Felt]felt.Felt + +func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } +func noFlush(current db.Batch) db.Batch { return current } + +type trieCase struct { + name string + oldBucket db.Bucket + newBucket db.Bucket + owner felt.Address + oldBuildPrefix func(owner felt.Address) []byte + newTrieID func(owner felt.Address) trieutils.TrieID + hashFn crypto.HashFn + //nolint:staticcheck // Necessary for old state + buildOldFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error) +} + +var allBackends = []struct { + name string + backend *dfsMigrator +}{ + {"dfsSerial", &dfsMigrator{parallelDispatch: false}}, + {"dfsParallel", &dfsMigrator{parallelDispatch: true}}, +} + +var trieCases = []trieCase{ + { + name: "ClassTrie", + oldBucket: db.ClassesTrie, + newBucket: db.ClassTrie, + oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.ClassesTrie)} }, + newTrieID: func(_ felt.Address) trieutils.TrieID { + return trieutils.NewClassTrieID(felt.StateRootHash(felt.One)) + }, + hashFn: crypto.Poseidon, + buildOldFn: trie.NewTriePoseidon, + }, + { + name: "ContractTrie", + oldBucket: db.StateTrie, + newBucket: db.ContractTrieContract, + oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.StateTrie)} }, + newTrieID: func(_ felt.Address) trieutils.TrieID { + return trieutils.NewContractTrieID(felt.StateRootHash(felt.One)) + }, + hashFn: crypto.Pedersen, + buildOldFn: trie.NewTriePedersen, + }, + { + name: "StorageTrie", + oldBucket: db.ContractStorage, + newBucket: db.ContractTrieStorage, + owner: felt.FromUint64[felt.Address](42), + oldBuildPrefix: func(owner felt.Address) []byte { + ownerFelt := felt.Felt(owner) + ownerBytes := ownerFelt.Bytes() + return db.ContractStorage.Key(ownerBytes[:]) + }, + newTrieID: func(owner felt.Address) trieutils.TrieID { + return trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), owner) + }, + hashFn: crypto.Pedersen, + buildOldFn: trie.NewTriePedersen, + }, +} + +// randomLeaves generates n distinct leaf key-value pairs using a fixed seed, +// with keys spread across the full 251-bit felt range for structural variety. +func randomLeaves(n int, seed int64) leafMap { + rng := rand.New(rand.NewSource(seed)) + leaves := make(leafMap, n) + var kb, vb [32]byte + for len(leaves) < n { + rng.Read(kb[:]) + rng.Read(vb[:]) + // Clear the top 5 bits so all keys are safely below the StarkNet prime (~2^251+δ). + kb[0] &= 0x07 + k := felt.FromBytes[felt.Felt](kb[:]) + v := felt.FromBytes[felt.Felt](vb[:]) + leaves[k] = v + } + return leaves +} + +var transcoderCases = []struct { + name string + leaves leafMap +}{ + // --- trivial structural cases --- + {"empty trie", nil}, + {"single leaf", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](100), + }}, + + // --- two-leaf cases probing specific split depths --- + + // Keys 2 and 3 differ only in the last path bit (bit 250 of felt = path bit 250). + // Root edge spans 250 bits; binary node is at maximum depth. + {"deep split", leafMap{ + felt.FromUint64[felt.Felt](2): felt.FromUint64[felt.Felt](10), + felt.FromUint64[felt.Felt](3): felt.FromUint64[felt.Felt](20), + }}, + + // One key < 2^250 (trie path bit[0]=0), one key ≥ 2^250 (trie path bit[0]=1). + // Root is a binary node — rootPath.Len()==0, no root edge written. + // This path is never hit by sequential small-integer keys. + {"left right split", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), + felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](20), // 2^250 + }}, + + // --- four leaves covering all 2-bit path prefixes (00, 01, 10, 11) --- + // Root is a binary node; left and right subtrees each contain a binary node. + // Tests two levels of binary processing with rootPath.Len()==0. + {"full depth 2 tree", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), // 00... + felt.FromBytes[felt.Felt]([]byte{0x02}): felt.FromUint64[felt.Felt](20), // 01... (2^249) + felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](30), // 10... (2^250) + felt.FromBytes[felt.Felt]([]byte{0x06}): felt.FromUint64[felt.Felt](40), // 11... (2^250+2^249) + }}, + + // --- larger sequential case for basic load --- + {"hundred sequential leaves", func() leafMap { + leaves := make(leafMap, 100) + for i := 1; i <= 100; i++ { + leaves[felt.FromUint64[felt.Felt](uint64(i))] = felt.FromUint64[felt.Felt](uint64(i) * 7) + } + return leaves + }()}, + + // --- random leaves spanning the full 251-bit space --- + // Keys are evenly distributed across all trie depths, exercising every code + // path: balanced binary nodes, varied edge lengths, and batch flush thresholds. + {"random 1000 leaves", randomLeaves(1000, 42)}, +} + +// TestMigrateTrieMatchesNativeTrie2 verifies that each backend produces byte-for-byte +// identical DB output to a natively-built trie2 for all three trie types and all leaf +// counts. This catches encoding bugs that root-hash comparison cannot detect. +func TestMigrationEndToEnd(t *testing.T) { + type testCase struct { + name string + tc trieCase + backend *dfsMigrator + leaves leafMap + } + + var cases []testCase + for _, tc := range trieCases { + for _, b := range allBackends { + for _, lc := range transcoderCases { + cases = append(cases, testCase{ + name: tc.name + "/" + b.name + "/" + lc.name, + tc: tc, + backend: b.backend, + leaves: lc.leaves, + }) + } + } + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + prefix := c.tc.oldBuildPrefix(c.tc.owner) + + migratedDB := memory.New() + buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) + runMigration(context.Background(), migratedDB, nopLogger()) + + nativeDB := memory.New() + buildTrie(t, nativeDB, c.leaves, + c.tc.newTrieID(c.tc.owner), c.tc.hashFn, c.tc.newBucket) + + assert.Equal(t, + allKeysUnder(t, nativeDB, c.tc.newBucket), + allKeysUnder(t, migratedDB, c.tc.newBucket)) + }) + } +} + +func buildDeprecatedTrie( + t *testing.T, + database db.KeyValueStore, + leaves leafMap, + //nolint:staticcheck // Necessary for old state + trieFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error), + prefix []byte, +) felt.Felt { + t.Helper() + //nolint:staticcheck // Necessary for old state + txn := database.NewIndexedBatch() + tr, err := trieFn(txn, prefix, 251) + require.NoError(t, err) + for key, value := range leaves { + _, err := tr.Put(&key, &value) + require.NoError(t, err) + } + root, err := tr.Root() + require.NoError(t, err) + require.NoError(t, tr.Commit()) + require.NoError(t, txn.Write()) + return root +} + +// buildNativeTrie2 builds a trie2 natively from leaves and persists it to kvStore. +// newBucket distinguishes class trie (db.ClassTrie) from contract/storage tries — +// it controls which Update argument the NodeSet is passed as. +func buildTrie( + t *testing.T, + kvStore db.KeyValueStore, + leaves leafMap, + id trieutils.TrieID, + hashFn crypto.HashFn, + newBucket db.Bucket, +) { + t.Helper() + rawDB := rawdb.New(kvStore) + tr, err := trie2.New(id, 251, hashFn, rawDB) + require.NoError(t, err) + for key, value := range leaves { + require.NoError(t, tr.Update(&key, &value)) + } + root, nodes := tr.Commit() + if nodes == nil { + return // empty trie — nothing to persist + } + mergeSet := trienode.NewMergeNodeSet(nodes) + var zero felt.StateRootHash + stateRoot := felt.StateRootHash(root) + batch := kvStore.NewBatch() + if newBucket == db.ClassTrie { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, mergeSet, nil, batch)) + } else { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, nil, mergeSet, batch)) + } + require.NoError(t, batch.Write()) +} + +func allKeysUnder(t *testing.T, r db.KeyValueReader, bucket db.Bucket) map[string][]byte { + t.Helper() + prefix := bucket.Key() + iter, err := r.NewIterator(prefix, true) + require.NoError(t, err) + defer iter.Close() + out := make(map[string][]byte) + for ok := iter.First(); ok; ok = iter.Next() { + val, err := iter.Value() + require.NoError(t, err) + out[string(iter.Key())] = val + } + return out +} diff --git a/migration/trie/hashpool.go b/migration/trie/hashpool.go new file mode 100644 index 0000000000..89710a33ee --- /dev/null +++ b/migration/trie/hashpool.go @@ -0,0 +1,59 @@ +package trie + +import ( + "sync" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" +) + +type hashWork struct { + hashFn crypto.HashFn + jobs []edgeHashJob + results []felt.Felt + wg *sync.WaitGroup +} + +type hashWorkerPool struct { + work chan hashWork + n int +} + +func newHashWorkerPool() *hashWorkerPool { + n := 4 + p := &hashWorkerPool{work: make(chan hashWork, n*2), n: n} + for range n { + go func() { + for w := range p.work { + for i := range w.jobs { + w.results[2*i] = computeEdgeHash(&w.jobs[i].leftChildHash, &w.jobs[i].leftSeg, w.hashFn) + w.results[2*i+1] = computeEdgeHash(&w.jobs[i].rightChildHash, &w.jobs[i].rightSeg, w.hashFn) + } + w.wg.Done() + } + }() + } + return p +} + +func (p *hashWorkerPool) submit( + hashFn crypto.HashFn, + jobs []edgeHashJob, + results []felt.Felt, +) <-chan struct{} { + done := make(chan struct{}) + go func() { + var wg sync.WaitGroup + chunkSize := max(1, (len(jobs)+p.n-1)/p.n) + for i := 0; i < len(jobs); i += chunkSize { + end := min(i+chunkSize, len(jobs)) + wg.Add(1) + p.work <- hashWork{hashFn, jobs[i:end], results[2*i : 2*end], &wg} + } + wg.Wait() + close(done) + }() + return done +} + +func (p *hashWorkerPool) close() { close(p.work) } diff --git a/migration/trie/hashworker.go b/migration/trie/hashworker.go new file mode 100644 index 0000000000..6cdbf3d28e --- /dev/null +++ b/migration/trie/hashworker.go @@ -0,0 +1,188 @@ +package trie + +import ( + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" +) + +type edgeHashJob struct { + leftChildHash, rightChildHash felt.Felt + leftSeg, rightSeg trieutils.Path + parentPath trieutils.Path +} + +type inFlightBatch struct { + jobs []edgeHashJob + results []felt.Felt + done <-chan struct{} +} + +type hashScheduler struct { + hashFn crypto.HashFn + parallel bool + bucket db.Bucket + owner felt.Address + pool *hashWorkerPool + + jobs []edgeHashJob + altJobs []edgeHashJob + results []felt.Felt + inFlightBuf inFlightBatch + hasInFlight bool +} + +func newHashScheduler( + hashFn crypto.HashFn, + parallel bool, + bucket db.Bucket, + owner *felt.Address, + pool *hashWorkerPool, +) *hashScheduler { + s := &hashScheduler{ + hashFn: hashFn, + parallel: parallel, + bucket: bucket, + owner: *owner, + pool: pool, + } + if parallel { + s.jobs = make([]edgeHashJob, 0, parallelHashBatchSize) + s.altJobs = make([]edgeHashJob, 0, parallelHashBatchSize) + s.results = make([]felt.Felt, 2*parallelHashBatchSize) + } + return s +} + +func (s *hashScheduler) reset( + hashFn crypto.HashFn, + parallel bool, + bucket db.Bucket, + owner *felt.Address, +) { + s.hashFn = hashFn + s.parallel = parallel + s.bucket = bucket + s.owner = *owner + s.hasInFlight = false + if parallel && s.jobs == nil { + s.jobs = make([]edgeHashJob, 0, parallelHashBatchSize) + s.altJobs = make([]edgeHashJob, 0, parallelHashBatchSize) + s.results = make([]felt.Felt, 2*parallelHashBatchSize) + } else if parallel { + s.jobs = s.jobs[:0] + s.altJobs = s.altJobs[:0] + } +} + +func (s *hashScheduler) schedule(job edgeHashJob, batch db.Batch) error { + if !s.parallel { + leftEdge := computeEdgeHash(&job.leftChildHash, &job.leftSeg, s.hashFn) + rightEdge := computeEdgeHash(&job.rightChildHash, &job.rightSeg, s.hashFn) + return s.writeBinaryAndEdges(job, &leftEdge, &rightEdge, batch) + } + s.jobs = append(s.jobs, job) + if len(s.jobs) >= parallelHashBatchSize { + return s.fire(batch) + } + return nil +} + +func (s *hashScheduler) fire(batch db.Batch) error { + if err := s.drainInFlight(batch); err != nil { + return err + } + results := s.results[:2*len(s.jobs)] + s.inFlightBuf = inFlightBatch{ + jobs: s.jobs, + results: results, + done: s.pool.submit(s.hashFn, s.jobs, results), + } + s.hasInFlight = true + s.jobs, s.altJobs = s.altJobs[:0], s.jobs + return nil +} + +func (s *hashScheduler) drainInFlight(batch db.Batch) error { + if !s.hasInFlight { + return nil + } + <-s.inFlightBuf.done + for i, job := range s.inFlightBuf.jobs { + err := s.writeBinaryAndEdges( + job, + &s.inFlightBuf.results[2*i], + &s.inFlightBuf.results[2*i+1], + batch, + ) + if err != nil { + return err + } + } + s.hasInFlight = false + return nil +} + +func (s *hashScheduler) sync(batch db.Batch) error { + if !s.parallel { + return nil + } + if err := s.drainInFlight(batch); err != nil { + return err + } + if len(s.jobs) > 0 { + results := s.results[:2*len(s.jobs)] + <-s.pool.submit(s.hashFn, s.jobs, results) + for i, job := range s.jobs { + if err := s.writeBinaryAndEdges( + job, + &results[2*i], + &results[2*i+1], + batch, + ); err != nil { + return err + } + } + s.jobs = s.jobs[:0] + } + return nil +} + +func (s *hashScheduler) writeBinaryAndEdges( + job edgeHashJob, + leftEdge, + rightEdge *felt.Felt, + batch db.Batch, +) error { + var buf [trieutils.MaxNodeKeySize + binaryNodeBlobSize]byte + keyLen := trieutils.EncodeNodeKey(buf[:], s.bucket, &s.owner, &job.parentPath, false) + blob := encodeBinaryNode(leftEdge, rightEdge) + copy(buf[keyLen:], blob[:]) + if err := batch.Put(buf[:keyLen], buf[keyLen:keyLen+binaryNodeBlobSize]); err != nil { + return err + } + + if job.leftSeg.Len() > 0 { + var leftEdgePath trieutils.Path + leftEdgePath.AppendBit(&job.parentPath, 0) + var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte + kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &leftEdgePath, false) + edgeBlob := encodeEdgeNodeInto(ebuf[kl:], &job.leftChildHash, &job.leftSeg) + if err := batch.Put(ebuf[:kl], ebuf[kl:kl+edgeBlob]); err != nil { + return err + } + } + + if job.rightSeg.Len() > 0 { + var rightEdgePath trieutils.Path + rightEdgePath.AppendBit(&job.parentPath, 1) + var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte + kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &rightEdgePath, false) + edgeBlob := encodeEdgeNodeInto(ebuf[kl:], &job.rightChildHash, &job.rightSeg) + if err := batch.Put(ebuf[:kl], ebuf[kl:kl+edgeBlob]); err != nil { + return err + } + } + return nil +} diff --git a/migration/trie/hashworker_test.go b/migration/trie/hashworker_test.go new file mode 100644 index 0000000000..7c02a47113 --- /dev/null +++ b/migration/trie/hashworker_test.go @@ -0,0 +1,154 @@ +package trie + +import ( + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeSimpleJob(parent *trieutils.Path, leftChild, rightChild *felt.Felt) edgeHashJob { + return edgeHashJob{ + leftChildHash: *leftChild, + rightChildHash: *rightChild, + parentPath: *parent, + } +} + +func expectedEdgeHash(child *felt.Felt, seg *trieutils.Path, hashFn crypto.HashFn) felt.Felt { + return computeEdgeHash(child, seg, hashFn) +} + +func TestHashScheduler_InlineWritesNodeToBatch(t *testing.T) { + memDB := memory.New() + batch := memDB.NewBatch() + + sched := newHashScheduler(crypto.Pedersen, false, db.ClassTrie, &felt.Address{}, nil) + parent := makeNewPath(3, 0b101) + left := felt.NewFromUint64[felt.Felt](1) + right := felt.NewFromUint64[felt.Felt](2) + + require.NoError(t, sched.schedule(makeSimpleJob(&parent, left, right), batch)) + require.NoError(t, batch.Write()) + + blob := encodeBinaryNode(left, right) + val, err := trieutils.GetNodeByPath(memDB, db.ClassTrie, &felt.Address{}, &parent, false) + require.NoError(t, err) + assert.Equal(t, blob[:], val) +} + +func TestHashScheduler_BatchedMatchesInlinePedersen(t *testing.T) { + testBatchedMatchesInline(t, crypto.Pedersen, 200) +} + +func TestHashScheduler_BatchedMatchesInlinePoseidon(t *testing.T) { + testBatchedMatchesInline(t, crypto.Poseidon, 200) +} + +func testBatchedMatchesInline(t *testing.T, hashFn crypto.HashFn, n int) { + t.Helper() + + pool := newHashWorkerPool() + t.Cleanup(pool.close) + + jobs := make([]edgeHashJob, n) + for i := range n { + l := felt.NewFromUint64[felt.Felt](uint64(i*2 + 1)) + r := felt.NewFromUint64[felt.Felt](uint64(i*2 + 2)) + path := makeNewPath(uint8(i%250), uint64(i)) + jobs[i] = makeSimpleJob(&path, l, r) + } + + inlineDB := memory.New() + { + batch := inlineDB.NewBatch() + sched := newHashScheduler(hashFn, false, db.ClassTrie, &felt.Address{}, nil) + for _, job := range jobs { + require.NoError(t, sched.schedule(job, batch)) + } + require.NoError(t, batch.Write()) + } + + batchDB := memory.New() + { + batch := batchDB.NewBatch() + sched := newHashScheduler(hashFn, true, db.ClassTrie, &felt.Address{}, pool) + for _, job := range jobs { + require.NoError(t, sched.schedule(job, batch)) + } + require.NoError(t, sched.sync(batch)) + require.NoError(t, batch.Write()) + } + + for i, job := range jobs { + inlineVal, err := trieutils.GetNodeByPath( + inlineDB, + db.ClassTrie, + &felt.Address{}, + &job.parentPath, + false, + ) + require.NoError(t, err, "path %d missing in inline", i) + batchVal, err := trieutils.GetNodeByPath( + batchDB, + db.ClassTrie, + &felt.Address{}, + &job.parentPath, + false, + ) + require.NoError(t, err, "path %d missing in batched", i) + assert.Equal(t, inlineVal, batchVal, "mismatch at path %d", i) + } +} + +func TestHashScheduler_AutoFlushAtBatchSize(t *testing.T) { + pool := newHashWorkerPool() + t.Cleanup(pool.close) + + memDB := memory.New() + batch := memDB.NewBatch() + + sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, &felt.Address{}, pool) + for i := range parallelHashBatchSize { + path := makeNewPath(251, uint64(i+1)) + leftChildHash := felt.NewFromUint64[felt.Felt](uint64(i + 1)) + rightChildHash := felt.NewFromUint64[felt.Felt](uint64(i + 2)) + job := makeSimpleJob(&path, leftChildHash, rightChildHash) + require.NoError(t, sched.schedule(job, batch)) + } + // After filling a full batch, jobs are handed off to pool; local slice is reset + assert.Empty(t, sched.jobs, "jobs should be empty after auto-flush") + // Drain in-flight batch before pool.close() to avoid send-on-closed-channel + require.NoError(t, sched.sync(batch)) +} + +func TestHashScheduler_SingleJobDispatchesCorrectly(t *testing.T) { + pool := newHashWorkerPool() + t.Cleanup(pool.close) + + memDB := memory.New() + batch := memDB.NewBatch() + + sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, &felt.Address{}, pool) + parent := makeNewPath(4, 0b1010) + l, r := felt.NewFromUint64[felt.Felt](7), felt.NewFromUint64[felt.Felt](13) + job := makeSimpleJob(&parent, l, r) + + require.NoError(t, sched.schedule(job, batch)) + require.NoError(t, sched.sync(batch)) + require.NoError(t, batch.Write()) + + blob := encodeBinaryNode(l, r) + val, err := trieutils.GetNodeByPath(memDB, db.ClassTrie, &felt.Address{}, &parent, false) + require.NoError(t, err) + assert.Equal(t, blob[:], val) +} + +func TestHashScheduler_LargeBatchDispatchesCorrectly(t *testing.T) { + testBatchedMatchesInline(t, crypto.Pedersen, parallelHashBatchSize*3) +} diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go new file mode 100644 index 0000000000..1700eea0f3 --- /dev/null +++ b/migration/trie/ingestor.go @@ -0,0 +1,95 @@ +package trie + +import ( + "context" + "fmt" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" +) + +type task struct { + batch db.Batch + tries int +} + +const dfsStackCap = 251 + +type ingestor struct { + logger log.StructuredLogger + database db.KeyValueReader + batchSemaphore semaphore.ResourceSemaphore[db.Batch] + tasks []task + pool *hashWorkerPool + dfsStacks [IngestorCount][]dfsFrame +} + +type FlushBatchFn func(db.Batch) db.Batch + +func newIngestor( + ctx context.Context, + database db.KeyValueReader, + batchSemaphore semaphore.ResourceSemaphore[db.Batch], + logger log.StructuredLogger, + pool *hashWorkerPool, +) (*ingestor, error) { + tasks := make([]task, IngestorCount) + for i := range tasks { + tasks[i] = task{batch: batchSemaphore.GetBlocking()} + } + + in := &ingestor{ + database: database, + batchSemaphore: batchSemaphore, + logger: logger, + tasks: tasks, + pool: pool, + } + for i := range in.dfsStacks { + in.dfsStacks[i] = make([]dfsFrame, 0, dfsStackCap) + } + return in, nil +} + +func (c *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { + done, err := hasDestRoot(c.database, desc.NewBucket, &desc.Owner) + if err != nil { + return fmt.Errorf("hasDestRoot(%v, %x): %w", desc.NewBucket, desc.Owner, err) + } + if done { + return nil + } + + t := &c.tasks[index] + + flush := FlushBatchFn(func(current db.Batch) db.Batch { + if current.Size() < targetBatchByteSize { + return current + } + outputs <- task{batch: current, tries: t.tries} + t.tries = 0 + return c.batchSemaphore.GetBlocking() + }) + + migrator := newDFSMigrator(desc.NodeCount >= SmallTrieThreshold, c.pool) + t.batch, c.dfsStacks[index], err = migrator.Migrate( + c.database, + t.batch, + desc, + flush, + c.dfsStacks[index], + ) + if err != nil { + return err + } + + t.tries++ + t.batch = flush(t.batch) + return nil +} + +func (c *ingestor) Done(index int, outputs chan<- task) error { + outputs <- c.tasks[index] + return nil +} diff --git a/migration/trie/trie.go b/migration/trie/trie.go new file mode 100644 index 0000000000..e35ece9dba --- /dev/null +++ b/migration/trie/trie.go @@ -0,0 +1,247 @@ +package trie + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/dbutils" + "github.com/NethermindEth/juno/migration" + "github.com/NethermindEth/juno/migration/pipeline" + "github.com/NethermindEth/juno/migration/semaphore" + "github.com/NethermindEth/juno/utils/log" +) + +const ( + batchByteSize = 128 * db.Megabyte + targetBatchByteSize = 96 * db.Megabyte + timeLogRate = 5 * time.Second + SmallTrieThreshold = 100_000 + parallelHashBatchSize = 16384 + IngestorCount = 4 +) + +var ( + shouldRerun = []byte{} + shouldNotRerun []byte +) + +type Migrator struct{} + +var _ migration.Migration = (*Migrator)(nil) + +func (*Migrator) Before([]byte) error { return nil } + +func (*Migrator) Migrate( + ctx context.Context, + database db.KeyValueStore, + _ *networks.Network, + logger log.StructuredLogger, +) ([]byte, error) { + needed, err := needsMigration(database) + if err != nil { + return shouldRerun, err + } + if !needed { + logger.Info("trie migration: no old-format data found, marking applied") + return shouldNotRerun, nil + } + + return runMigration(ctx, database, logger) +} + +func needsMigration(r db.KeyValueReader) (bool, error) { + for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { + prefix := bucket.Key() + iter, err := r.NewIterator(prefix, true) + if err != nil { + return false, err + } + hasKeys := iter.First() + if err := iter.Close(); err != nil { + return false, err + } + if hasKeys { + return true, nil + } + } + return false, nil +} + +func runMigration( + ctx context.Context, + database db.KeyValueStore, + logger log.StructuredLogger, +) ([]byte, error) { + batchSem := semaphore.New(IngestorCount*2, func() db.Batch { + return database.NewBatchWithSize(batchByteSize) + }) + + pool := newHashWorkerPool() + + ing, err := newIngestor(ctx, database, batchSem, logger, pool) + if err != nil { + return shouldRerun, err + } + + tries, err := enumerateTries(database) + if err != nil { + return shouldRerun, err + } + + src := pipeline.Source(func(yield func(TrieDesc) bool) { + for _, d := range tries { + if !yield(d) { + return + } + } + }) + ingested := pipeline.New(src, IngestorCount, ing) + committed := pipeline.New( + ingested, + 1, + newCommitter(logger, batchSem), + ) + + _, wait := committed.Run(ctx) + res := wait() + if !res.IsDone { + pool.close() + return shouldRerun, res.Err + } + + pool.close() + + for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { + start := bucket.Key() + end := dbutils.UpperBound(start) + if err := database.DeleteRange(start, end); err != nil { + return shouldRerun, fmt.Errorf("trie migration: cleanup DeleteRange for %v: %w", bucket, err) + } + } + logger.Info("trie migration: source buckets deleted") + + return shouldNotRerun, nil +} + +type TrieDesc struct { + OldBucket db.Bucket + NewBucket db.Bucket + Owner felt.Address + HashFn crypto.HashFn + NodeCount int + RootPath *trie.BitArray +} + +func enumerateTries(r db.KeyValueReader) ([]TrieDesc, error) { + var descs []TrieDesc + + for _, spec := range []struct { + oldBucket, newBucket db.Bucket + hashFn crypto.HashFn + }{ + {db.ClassesTrie, db.ClassTrie, crypto.Poseidon}, + {db.StateTrie, db.ContractTrieContract, crypto.Pedersen}, + } { + prefix := spec.oldBucket.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + return nil, fmt.Errorf("opening iterator for bucket %v: %w", spec.oldBucket, err) + } + var rootPath *trie.BitArray + count := 0 + for valid := it.First(); valid; valid = it.Next() { + key := it.Key() + if len(key) == len(prefix) { + if val, verr := it.Value(); verr == nil { + rootPath = parseRootPath(val) + } + } else { + count++ + } + } + it.Close() + descs = append(descs, TrieDesc{ + OldBucket: spec.oldBucket, + NewBucket: spec.newBucket, + HashFn: spec.hashFn, + NodeCount: count, + RootPath: rootPath, + }) + } + + storagePrefix := db.ContractStorage.Key() + storageIter, err := r.NewIterator(storagePrefix, true) + if err != nil { + return nil, fmt.Errorf("opening storage iterator: %w", err) + } + for valid := storageIter.First(); valid; valid = storageIter.Valid() { + key := storageIter.Key() + if len(key) < 1+felt.Bytes { + storageIter.Next() + continue + } + ownerFelt := felt.FromBytes[felt.Felt](key[1 : 1+felt.Bytes]) + owner := felt.Address(ownerFelt) + ownerBytes := ownerFelt.Bytes() + ownerPrefix := db.ContractStorage.Key(ownerBytes[:]) + + var rootPath *trie.BitArray + count := 0 + for storageIter.Valid() { + k := storageIter.Key() + if !bytes.HasPrefix(k, ownerPrefix) { + break + } + if len(k) == len(ownerPrefix) { + if val, verr := storageIter.Value(); verr == nil { + rootPath = parseRootPath(val) + } + } else { + count++ + } + storageIter.Next() + } + descs = append(descs, TrieDesc{ + OldBucket: db.ContractStorage, + NewBucket: db.ContractTrieStorage, + Owner: owner, + HashFn: crypto.Pedersen, + NodeCount: count, + RootPath: rootPath, + }) + } + storageIter.Close() + + return descs, nil +} + +func parseRootPath(val []byte) *trie.BitArray { + if len(val) == 0 { + return nil + } + var ba trie.BitArray + if err := ba.UnmarshalBinary(val); err != nil { + return nil + } + return &ba +} + +func hasDestRoot(r db.KeyValueReader, newBucket db.Bucket, owner *felt.Address) (bool, error) { + var emptyPath trieutils.Path + var buf [trieutils.MaxNodeKeySize]byte + + n := trieutils.EncodeNodeKey(buf[:], newBucket, owner, &emptyPath, false) + if exists, err := r.Has(buf[:n]); err != nil || exists { + return exists, err + } + n = trieutils.EncodeNodeKey(buf[:], newBucket, owner, &emptyPath, true) + return r.Has(buf[:n]) +} diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go new file mode 100644 index 0000000000..d3c5322cf9 --- /dev/null +++ b/migration/trie/trie_test.go @@ -0,0 +1,281 @@ +package trie + +import ( + "bytes" + "context" + "slices" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMigrate_FreshDBIsNoOp(t *testing.T) { + testDB := memory.New() + t.Cleanup(func() { testDB.Close() }) + + state, err := (&Migrator{}).Migrate(context.Background(), testDB, nil, nopLogger()) + require.NoError(t, err) + assert.Nil( + t, + state, + "fresh DB must mark migration applied (nil intermediate state) without doing work", + ) +} + +func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { + leaves := randomLeaves(100, 7) + testDB, _, _, _, _ := buildFullDB(t, leaves) + + needed, err := needsMigration(testDB) + require.NoError(t, err) + require.True(t, needed, "precondition: DB has old-format data") + + state, err := (&Migrator{}).Migrate(context.Background(), testDB, nil, nopLogger()) + require.NoError(t, err) + assert.Nil(t, state, "completed migration must return nil intermediate state") + + stillNeeded, err := needsMigration(testDB) + require.NoError(t, err) + assert.False(t, stillNeeded, "old-format buckets should be empty after migration") +} + +func TestMigrationIsResumable(t *testing.T) { + leaves := randomLeaves(1000, 42) + + // Reference: full migration from scratch. + refDB, _, _, _, _ := buildFullDB(t, leaves) + _, err := runMigration(context.Background(), refDB, nopLogger()) + require.NoError(t, err) + + // Partial DB: both tries in old format initially. + partialDB, _, _, _, _ := buildFullDB(t, leaves) + + // Manually migrate only the class trie to simulate a mid-run interruption. + classPrefix := db.ClassesTrie.Key() + var classRootPath *trie.BitArray + require.NoError(t, partialDB.Get(classPrefix, func(val []byte) error { + classRootPath = parseRootPath(val) + return nil + })) + classDesc := TrieDesc{ + OldBucket: db.ClassesTrie, + NewBucket: db.ClassTrie, + HashFn: crypto.Poseidon, + NodeCount: len(leaves), + RootPath: classRootPath, + } + batch := partialDB.NewBatch() + + pool := newHashWorkerPool() + defer pool.close() + migrator := newDFSMigrator(true, pool) + stack := make([]dfsFrame, 0, dfsStackCap) + _, _, err = migrator.Migrate(partialDB, batch, classDesc, noFlush, stack) + require.NoError(t, err) + require.NoError(t, batch.Write()) + + done, err := hasDestRoot(partialDB, db.ClassTrie, &classDesc.Owner) + require.NoError(t, err) + assert.True(t, done, "class trie destination root should be present after partial migration") + + needed, err := needsMigration(partialDB) + require.NoError(t, err) + assert.True(t, needed, "migration should still be needed after partial work") + + // enumerateTries yields every discovered trie unfiltered; the + // already-done short-circuit lives in the ingestor now. We assert that + // the class trie *is* yielded here, and rely on the end-to-end equality + // check below to confirm the ingestor's idempotency guard skips it. + descsRemaining := collectTries(t, partialDB) + foundClass := false + for _, d := range descsRemaining { + if d.OldBucket == db.ClassesTrie { + foundClass = true + } + } + assert.True( + t, + foundClass, + "enumerateTries should yield class trie even when destination root is present", + ) + + // Resume: migration completes the remaining contract trie. + _, err = runMigration(context.Background(), partialDB, nopLogger()) + require.NoError(t, err) + + // Final state must match the reference full-run output for every new-format bucket. + for _, bucket := range []db.Bucket{db.ClassTrie, db.ContractTrieContract} { + refKeys := allKeysUnder(t, refDB, bucket) + resumedKeys := allKeysUnder(t, partialDB, bucket) + assert.Equal(t, refKeys, resumedKeys, + "resumed migration result differs from full run for bucket %v", bucket) + } +} + +// buildFullDB creates an old-format DB with class, contract, and one storage trie, +// all populated with the same leaf set. Returns the DB and the old-format root hashes. +func buildFullDB(t *testing.T, leaves leafMap) ( + database db.KeyValueStore, + classRoot felt.Felt, + contractRoot felt.Felt, + storageRoot felt.Felt, + owner felt.Address, +) { + t.Helper() + database = memory.New() + + classRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) + contractRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, db.StateTrie.Key()) + + var ownerFelt felt.Felt + ownerFelt.SetUint64(42) + owner = felt.Address(ownerFelt) + ownerBytes := ownerFelt.Bytes() + storagePrefix := db.ContractStorage.Key(ownerBytes[:]) + storageRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, storagePrefix) + + return database, classRoot, contractRoot, storageRoot, owner +} + +func insertFakeStorageNodes(t *testing.T, database db.KeyValueStore, owner felt.Address, n int) { + t.Helper() + ownerFelt := felt.Felt(owner) + ownerBytes := ownerFelt.Bytes() + ownerPrefix := db.ContractStorage.Key(ownerBytes[:]) + + require.NoError(t, database.Put(ownerPrefix, []byte{0})) + + for i := range n { + key := append(bytes.Clone(ownerPrefix), 8, byte(i)) + require.NoError(t, database.Put(key, []byte{0xFF})) + } +} + +func collectTries(t *testing.T, r db.KeyValueReader) []TrieDesc { + t.Helper() + descs, err := enumerateTries(r) + require.NoError(t, err) + return descs +} + +func TestEnumerateTries_EmptyDBYieldsClassAndContractTries(t *testing.T) { + testDB := memory.New() + descs := collectTries(t, testDB) + require.Len(t, descs, 2) + assert.Equal(t, db.ClassesTrie, descs[0].OldBucket) + assert.Equal(t, db.StateTrie, descs[1].OldBucket) +} + +func TestEnumerateTries_GlobalTriesPresent(t *testing.T) { + testDB := memory.New() + descs := collectTries(t, testDB) + hasClass := slices.ContainsFunc(descs, func(d TrieDesc) bool { + return d.OldBucket == db.ClassesTrie && d.NewBucket == db.ClassTrie + }) + hasContract := slices.ContainsFunc(descs, func(d TrieDesc) bool { + return d.OldBucket == db.StateTrie && d.NewBucket == db.ContractTrieContract + }) + assert.True(t, hasClass) + assert.True(t, hasContract) +} + +func TestEnumerateTries_StorageTriesDiscovered(t *testing.T) { + testDB := memory.New() + var owners [3]felt.Address + for i := range owners { + var f felt.Felt + f.SetUint64(uint64(i + 1)) + owners[i] = felt.Address(f) + insertFakeStorageNodes(t, testDB, owners[i], 5) + } + + descs := collectTries(t, testDB) + require.Len(t, descs, 5) + + storageCount := 0 + for _, d := range descs { + if d.OldBucket == db.ContractStorage { + assert.Equal(t, db.ContractTrieStorage, d.NewBucket) + storageCount++ + } + } + assert.Equal(t, 3, storageCount) +} + +func TestEnumerateTries_NodeCountIsCorrect(t *testing.T) { + testDB := memory.New() + var ownerFelt felt.Felt + ownerFelt.SetUint64(99) + owner := felt.Address(ownerFelt) + insertFakeStorageNodes(t, testDB, owner, 7) + + descs := collectTries(t, testDB) + require.Len(t, descs, 3) + + idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) + require.NotEqual(t, -1, idx) + assert.Equal(t, 7, descs[idx].NodeCount) +} + +func TestEnumerateTries_StorageTrieCountsPresent(t *testing.T) { + testDB := memory.New() + for i, n := range []int{3, 7, 1} { + var f felt.Felt + f.SetUint64(uint64(i + 1)) + insertFakeStorageNodes(t, testDB, felt.Address(f), n) + } + + descs := collectTries(t, testDB) + var storageCounts []int + for _, d := range descs { + if d.OldBucket == db.ContractStorage { + storageCounts = append(storageCounts, d.NodeCount) + } + } + slices.Sort(storageCounts) + assert.Equal(t, []int{1, 3, 7}, storageCounts) +} + +func TestEnumerateTries_StorageTrieOwnerMatchesKey(t *testing.T) { + testDB := memory.New() + var ownerFelt felt.Felt + ownerFelt.SetUint64(12345) + owner := felt.Address(ownerFelt) + insertFakeStorageNodes(t, testDB, owner, 3) + + descs := collectTries(t, testDB) + require.Len(t, descs, 3) + + idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) + require.NotEqual(t, -1, idx) + assert.Equal(t, owner, descs[idx].Owner) +} + +func TestEnumerateTries_MultipleOwnersOrdered(t *testing.T) { + testDB := memory.New() + owners := make([]felt.Address, 5) + for i := range owners { + var f felt.Felt + f.SetUint64(uint64(i + 1)) + owners[i] = felt.Address(f) + insertFakeStorageNodes(t, testDB, owners[i], i+1) + } + + descs := collectTries(t, testDB) + require.Len(t, descs, 7) + + var storageCounts []int + for _, d := range descs { + if d.OldBucket == db.ContractStorage { + storageCounts = append(storageCounts, d.NodeCount) + } + } + slices.Sort(storageCounts) + assert.Equal(t, []int{1, 2, 3, 4, 5}, storageCounts) +} diff --git a/node/migration.go b/node/migration.go index 57a53e9eac..2ca79e2332 100644 --- a/node/migration.go +++ b/node/migration.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/migration/headstate" "github.com/NethermindEth/juno/migration/historyprunner" "github.com/NethermindEth/juno/migration/statehistory" + "github.com/NethermindEth/juno/migration/trie" "github.com/NethermindEth/juno/utils/log" ) @@ -29,7 +30,8 @@ func registerMigrations(cfg *Config) *migration.Registry { PruneModeFlag, ). WithOptional(&headstate.Migrator{}, cfg.NewState, "new-state"). - WithOptional(&statehistory.Migrator{}, cfg.NewState, "new-state") + WithOptional(&statehistory.Migrator{}, cfg.NewState, "new-state"). + WithOptional(&trie.Migrator{}, cfg.NewState, "new-state") return registry } From 3f52fa3e82b54ecfb83a327f98ff6e6b93e90a0a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 01:45:30 +0200 Subject: [PATCH 08/14] chore: linter --- migration/trie/codec.go | 10 -- migration/trie/codec_test.go | 9 +- migration/trie/counter.go | 2 +- migration/trie/dfsmigration.go | 110 ++++++++++++--------- migration/trie/dfsmigration_test.go | 37 +++---- migration/trie/hashworker.go | 39 ++------ migration/trie/hashworker_test.go | 31 +++--- migration/trie/ingestor.go | 61 ++++++------ migration/trie/trie.go | 145 ++++++++++++++++++---------- migration/trie/trie_test.go | 83 ++++++++-------- 10 files changed, 274 insertions(+), 253 deletions(-) diff --git a/migration/trie/codec.go b/migration/trie/codec.go index 77fb642efc..77e9264113 100644 --- a/migration/trie/codec.go +++ b/migration/trie/codec.go @@ -38,16 +38,6 @@ func encodeBinaryNode(leftEdgeHash, rightEdgeHash *felt.Felt) [binaryNodeBlobSiz return blob } -func encodeEdgeNode(childHash *felt.Felt, pathSeg *trieutils.Path) []byte { - encoded := pathSeg.EncodedBytes() - var arr [edgeNodeMaxSize]byte - arr[0] = edgeNodeTag - h := childHash.Bytes() - copy(arr[1:], h[:]) - copy(arr[1+felt.Bytes:], encoded) - return arr[:1+felt.Bytes+len(encoded)] -} - func encodeEdgeNodeInto(dst []byte, childHash *felt.Felt, pathSeg *trieutils.Path) int { encoded := pathSeg.EncodedBytes() dst[0] = edgeNodeTag diff --git a/migration/trie/codec_test.go b/migration/trie/codec_test.go index e468af7aa1..9a4f53722a 100644 --- a/migration/trie/codec_test.go +++ b/migration/trie/codec_test.go @@ -74,9 +74,12 @@ func TestEncodeEdgeNode(t *testing.T) { var childHash felt.Felt childHash.SetUint64(42) seg := makeNewPath(pathLen, 0b101) - blob := encodeEdgeNode(&childHash, &seg) - require.Greater(t, len(blob), 0) + var buf [edgeNodeMaxSize]byte + n := encodeEdgeNodeInto(buf[:], &childHash, &seg) + blob := buf[:n] + + require.Greater(t, n, 0) assert.Equal(t, edgeNodeTag, blob[0]) var got felt.Felt @@ -84,7 +87,7 @@ func TestEncodeEdgeNode(t *testing.T) { assert.Equal(t, childHash, got) activeBytes := (int(pathLen) + 7) / 8 - assert.Equal(t, 1+felt.Bytes+activeBytes+1, len(blob)) + assert.Equal(t, 1+felt.Bytes+activeBytes+1, n) }) } } diff --git a/migration/trie/counter.go b/migration/trie/counter.go index 2d64cb4b5a..58978cba12 100644 --- a/migration/trie/counter.go +++ b/migration/trie/counter.go @@ -30,7 +30,7 @@ func (c *counter) log(byteSize uint64, tries int) { now := time.Now() elapsed := now.Sub(c.start).Seconds() - if elapsed > float64(c.timeLogRate.Seconds()) { + if elapsed > c.timeLogRate.Seconds() { mbs := float64(c.size) / float64(db.Megabyte) c.logger.Info( "write speed", diff --git a/migration/trie/dfsmigration.go b/migration/trie/dfsmigration.go index 65e28421d7..4ea77dc3be 100644 --- a/migration/trie/dfsmigration.go +++ b/migration/trie/dfsmigration.go @@ -33,7 +33,7 @@ func (m *dfsMigrator) Migrate( rootPath := desc.RootPath sched := newHashScheduler(desc.HashFn, m.parallelDispatch, - desc.NewBucket, &desc.Owner, m.pool) + desc.NewBucket, desc.Owner, m.pool) rootHash, batch, stack, err := traverse(r, prefix, rootPath, sched, batch, flush, stack) if err != nil { @@ -53,20 +53,26 @@ func (m *dfsMigrator) Migrate( type traverseStackState uint8 const ( - readNode traverseStackState = iota - leftSubtreeDone - rightSubtreeDone + readNodeState traverseStackState = iota + leftSubtreeDoneState + rightSubtreeDoneState ) type dfsFrame struct { - oldPath trie.BitArray - left trie.BitArray - right trie.BitArray - value felt.Felt - leftHash felt.Felt - isLeaf bool - state traverseStackState - binaryDepth uint8 + oldPath trie.BitArray + left trie.BitArray + right trie.BitArray + value felt.Felt + leftHash felt.Felt + isLeaf bool + state traverseStackState +} + +type parsedNode struct { + value felt.Felt + left trie.BitArray + right trie.BitArray + isLeaf bool } func traverse( @@ -84,12 +90,17 @@ func traverse( for len(stack) > 0 { top := &stack[len(stack)-1] switch top.state { - case readNode: - if err := parseNodeInto(r, prefix, top); err != nil { + case readNodeState: + parsed, err := readNode(r, prefix, &top.oldPath) + if err != nil { return felt.Felt{}, batch, stack, err } - newPath := toNewPath(&top.oldPath) + top.value = parsed.value + top.left = parsed.left + top.right = parsed.right + top.isLeaf = parsed.isLeaf if top.isLeaf { + newPath := toNewPath(&top.oldPath) if err := processLeaf(newPath, &top.value, sched, batch); err != nil { return felt.Felt{}, batch, stack, err } @@ -97,15 +108,15 @@ func traverse( stack = stack[:len(stack)-1] } else { left := top.left - top.state = leftSubtreeDone + top.state = leftSubtreeDoneState stack = pushFrame(stack, left) } - case leftSubtreeDone: + case leftSubtreeDoneState: top.leftHash = lastHash right := top.right - top.state = rightSubtreeDone + top.state = rightSubtreeDoneState stack = pushFrame(stack, right) - case rightSubtreeDone: + case rightSubtreeDoneState: newPath := toNewPath(&top.oldPath) if err := processBinary( newPath, @@ -121,7 +132,11 @@ func traverse( lastHash = top.value stack = stack[:len(stack)-1] } - batch = flush(batch) + var err error + batch, err = flush(batch) + if err != nil { + return felt.Felt{}, batch, stack, err + } } return lastHash, batch, stack, nil } @@ -144,34 +159,45 @@ func encodeOldPath(path *trie.BitArray, dst []byte) int { return int(activeBytes) + 1 } -func parseNodeInto(r db.KeyValueReader, prefix []byte, frame *dfsFrame) error { +// readNode loads the deprecated-format node at (prefix, oldPath) and returns +// its parsed fields. The caller assigns the parsed values into its own state +// (e.g. a DFS stack frame); this function does not mutate any input. +func readNode(r db.KeyValueReader, prefix []byte, oldPath *trie.BitArray) (parsedNode, error) { var arr [maxOldKeySize]byte n := copy(arr[:], prefix) - n += encodeOldPath(&frame.oldPath, arr[n:]) - return r.Get(arr[:n], func(val []byte) error { - return parseNodeData(val, &frame.value, &frame.left, &frame.right, &frame.isLeaf) + n += encodeOldPath(oldPath, arr[n:]) + var node parsedNode + err := r.Get(arr[:n], func(val []byte) error { + var perr error + node, perr = parseNodeData(val) + return perr }) + return node, err } -func parseNodeData(data []byte, value *felt.Felt, left, right *trie.BitArray, isLeaf *bool) error { +// parseNodeData decodes a deprecated-format node's raw bytes: +// felt(value) [ BitArray(left) BitArray(right) [ felt felt ] ] +// The trailing left/right hashes are ignored — the migrator re-derives hashes +// itself — so only the fields it actually needs are returned. +func parseNodeData(data []byte) (parsedNode, error) { + var n parsedNode if len(data) < felt.Bytes { - return fmt.Errorf("trie: node data too short (%d bytes)", len(data)) + return n, fmt.Errorf("trie: node data too short (%d bytes)", len(data)) } - *value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) + n.value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) data = data[felt.Bytes:] if len(data) == 0 { - *isLeaf = true - return nil + n.isLeaf = true + return n, nil } - *isLeaf = false - if err := left.UnmarshalBinary(data); err != nil { - return fmt.Errorf("trie: unmarshalling left path: %w", err) + if err := n.left.UnmarshalBinary(data); err != nil { + return n, fmt.Errorf("trie: unmarshalling left path: %w", err) } - data = data[left.EncodedLen():] - if err := right.UnmarshalBinary(data); err != nil { - return fmt.Errorf("trie: unmarshalling right path: %w", err) + data = data[n.left.EncodedLen():] + if err := n.right.UnmarshalBinary(data); err != nil { + return n, fmt.Errorf("trie: unmarshalling right path: %w", err) } - return nil + return n, nil } func processLeaf( @@ -196,12 +222,7 @@ func processBinary( ) error { leftSeg := compressedSegment(left, parentPath.Len()) rightSeg := compressedSegment(right, parentPath.Len()) - leftFull := toNewPath(left) - rightFull := toNewPath(right) - var leftEdgePath, rightEdgePath trieutils.Path - leftEdgePath.MSBs(&leftFull, parentPath.Len()+1) - rightEdgePath.MSBs(&rightFull, parentPath.Len()+1) - return sched.schedule(edgeHashJob{ + return sched.schedule(&edgeHashJob{ leftChildHash: leftChildHash, leftSeg: leftSeg, rightChildHash: rightChildHash, @@ -217,8 +238,9 @@ func writeRootEdge( batch db.Batch, ) error { seg := toNewPath(rootPath) - blob := encodeEdgeNode(&childHash, &seg) - return trieutils.WriteNodeByPath(batch, sched.bucket, &sched.owner, &trieutils.Path{}, false, blob) + var buf [edgeNodeMaxSize]byte + n := encodeEdgeNodeInto(buf[:], &childHash, &seg) + return trieutils.WriteNodeByPath(batch, sched.bucket, &sched.owner, &trieutils.Path{}, false, buf[:n]) } func oldTriePrefix(desc TrieDesc) []byte { diff --git a/migration/trie/dfsmigration_test.go b/migration/trie/dfsmigration_test.go index 2ff05a252e..79e220de27 100644 --- a/migration/trie/dfsmigration_test.go +++ b/migration/trie/dfsmigration_test.go @@ -21,8 +21,8 @@ import ( type leafMap map[felt.Felt]felt.Felt -func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } -func noFlush(current db.Batch) db.Batch { return current } +func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } +func noFlush(current db.Batch) (db.Batch, error) { return current, nil } type trieCase struct { name string @@ -36,14 +36,6 @@ type trieCase struct { buildOldFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error) } -var allBackends = []struct { - name string - backend *dfsMigrator -}{ - {"dfsSerial", &dfsMigrator{parallelDispatch: false}}, - {"dfsParallel", &dfsMigrator{parallelDispatch: true}}, -} - var trieCases = []trieCase{ { name: "ClassTrie", @@ -160,23 +152,19 @@ var transcoderCases = []struct { // counts. This catches encoding bugs that root-hash comparison cannot detect. func TestMigrationEndToEnd(t *testing.T) { type testCase struct { - name string - tc trieCase - backend *dfsMigrator - leaves leafMap + name string + tc trieCase + leaves leafMap } var cases []testCase for _, tc := range trieCases { - for _, b := range allBackends { - for _, lc := range transcoderCases { - cases = append(cases, testCase{ - name: tc.name + "/" + b.name + "/" + lc.name, - tc: tc, - backend: b.backend, - leaves: lc.leaves, - }) - } + for _, lc := range transcoderCases { + cases = append(cases, testCase{ + name: tc.name + "/" + lc.name, + tc: tc, + leaves: lc.leaves, + }) } } @@ -186,7 +174,8 @@ func TestMigrationEndToEnd(t *testing.T) { migratedDB := memory.New() buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) - runMigration(context.Background(), migratedDB, nopLogger()) + _, err := runMigration(context.Background(), migratedDB, nopLogger()) + require.NoError(t, err) nativeDB := memory.New() buildTrie(t, nativeDB, c.leaves, diff --git a/migration/trie/hashworker.go b/migration/trie/hashworker.go index 6cdbf3d28e..f0ed866b16 100644 --- a/migration/trie/hashworker.go +++ b/migration/trie/hashworker.go @@ -37,14 +37,14 @@ func newHashScheduler( hashFn crypto.HashFn, parallel bool, bucket db.Bucket, - owner *felt.Address, + owner felt.Address, pool *hashWorkerPool, ) *hashScheduler { s := &hashScheduler{ hashFn: hashFn, parallel: parallel, bucket: bucket, - owner: *owner, + owner: owner, pool: pool, } if parallel { @@ -55,34 +55,13 @@ func newHashScheduler( return s } -func (s *hashScheduler) reset( - hashFn crypto.HashFn, - parallel bool, - bucket db.Bucket, - owner *felt.Address, -) { - s.hashFn = hashFn - s.parallel = parallel - s.bucket = bucket - s.owner = *owner - s.hasInFlight = false - if parallel && s.jobs == nil { - s.jobs = make([]edgeHashJob, 0, parallelHashBatchSize) - s.altJobs = make([]edgeHashJob, 0, parallelHashBatchSize) - s.results = make([]felt.Felt, 2*parallelHashBatchSize) - } else if parallel { - s.jobs = s.jobs[:0] - s.altJobs = s.altJobs[:0] - } -} - -func (s *hashScheduler) schedule(job edgeHashJob, batch db.Batch) error { +func (s *hashScheduler) schedule(job *edgeHashJob, batch db.Batch) error { if !s.parallel { leftEdge := computeEdgeHash(&job.leftChildHash, &job.leftSeg, s.hashFn) rightEdge := computeEdgeHash(&job.rightChildHash, &job.rightSeg, s.hashFn) return s.writeBinaryAndEdges(job, &leftEdge, &rightEdge, batch) } - s.jobs = append(s.jobs, job) + s.jobs = append(s.jobs, *job) if len(s.jobs) >= parallelHashBatchSize { return s.fire(batch) } @@ -109,9 +88,9 @@ func (s *hashScheduler) drainInFlight(batch db.Batch) error { return nil } <-s.inFlightBuf.done - for i, job := range s.inFlightBuf.jobs { + for i := range s.inFlightBuf.jobs { err := s.writeBinaryAndEdges( - job, + &s.inFlightBuf.jobs[i], &s.inFlightBuf.results[2*i], &s.inFlightBuf.results[2*i+1], batch, @@ -134,9 +113,9 @@ func (s *hashScheduler) sync(batch db.Batch) error { if len(s.jobs) > 0 { results := s.results[:2*len(s.jobs)] <-s.pool.submit(s.hashFn, s.jobs, results) - for i, job := range s.jobs { + for i := range s.jobs { if err := s.writeBinaryAndEdges( - job, + &s.jobs[i], &results[2*i], &results[2*i+1], batch, @@ -150,7 +129,7 @@ func (s *hashScheduler) sync(batch db.Batch) error { } func (s *hashScheduler) writeBinaryAndEdges( - job edgeHashJob, + job *edgeHashJob, leftEdge, rightEdge *felt.Felt, batch db.Batch, diff --git a/migration/trie/hashworker_test.go b/migration/trie/hashworker_test.go index 7c02a47113..0c89893f76 100644 --- a/migration/trie/hashworker_test.go +++ b/migration/trie/hashworker_test.go @@ -12,23 +12,19 @@ import ( "github.com/stretchr/testify/require" ) -func makeSimpleJob(parent *trieutils.Path, leftChild, rightChild *felt.Felt) edgeHashJob { - return edgeHashJob{ +func makeSimpleJob(parent *trieutils.Path, leftChild, rightChild *felt.Felt) *edgeHashJob { + return &edgeHashJob{ leftChildHash: *leftChild, rightChildHash: *rightChild, parentPath: *parent, } } -func expectedEdgeHash(child *felt.Felt, seg *trieutils.Path, hashFn crypto.HashFn) felt.Felt { - return computeEdgeHash(child, seg, hashFn) -} - func TestHashScheduler_InlineWritesNodeToBatch(t *testing.T) { memDB := memory.New() batch := memDB.NewBatch() - sched := newHashScheduler(crypto.Pedersen, false, db.ClassTrie, &felt.Address{}, nil) + sched := newHashScheduler(crypto.Pedersen, false, db.ClassTrie, felt.Address{}, nil) parent := makeNewPath(3, 0b101) left := felt.NewFromUint64[felt.Felt](1) right := felt.NewFromUint64[felt.Felt](2) @@ -61,15 +57,15 @@ func testBatchedMatchesInline(t *testing.T, hashFn crypto.HashFn, n int) { l := felt.NewFromUint64[felt.Felt](uint64(i*2 + 1)) r := felt.NewFromUint64[felt.Felt](uint64(i*2 + 2)) path := makeNewPath(uint8(i%250), uint64(i)) - jobs[i] = makeSimpleJob(&path, l, r) + jobs[i] = *makeSimpleJob(&path, l, r) } inlineDB := memory.New() { batch := inlineDB.NewBatch() - sched := newHashScheduler(hashFn, false, db.ClassTrie, &felt.Address{}, nil) - for _, job := range jobs { - require.NoError(t, sched.schedule(job, batch)) + sched := newHashScheduler(hashFn, false, db.ClassTrie, felt.Address{}, nil) + for i := range jobs { + require.NoError(t, sched.schedule(&jobs[i], batch)) } require.NoError(t, batch.Write()) } @@ -77,15 +73,16 @@ func testBatchedMatchesInline(t *testing.T, hashFn crypto.HashFn, n int) { batchDB := memory.New() { batch := batchDB.NewBatch() - sched := newHashScheduler(hashFn, true, db.ClassTrie, &felt.Address{}, pool) - for _, job := range jobs { - require.NoError(t, sched.schedule(job, batch)) + sched := newHashScheduler(hashFn, true, db.ClassTrie, felt.Address{}, pool) + for i := range jobs { + require.NoError(t, sched.schedule(&jobs[i], batch)) } require.NoError(t, sched.sync(batch)) require.NoError(t, batch.Write()) } - for i, job := range jobs { + for i := range jobs { + job := &jobs[i] inlineVal, err := trieutils.GetNodeByPath( inlineDB, db.ClassTrie, @@ -113,7 +110,7 @@ func TestHashScheduler_AutoFlushAtBatchSize(t *testing.T) { memDB := memory.New() batch := memDB.NewBatch() - sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, &felt.Address{}, pool) + sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, felt.Address{}, pool) for i := range parallelHashBatchSize { path := makeNewPath(251, uint64(i+1)) leftChildHash := felt.NewFromUint64[felt.Felt](uint64(i + 1)) @@ -134,7 +131,7 @@ func TestHashScheduler_SingleJobDispatchesCorrectly(t *testing.T) { memDB := memory.New() batch := memDB.NewBatch() - sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, &felt.Address{}, pool) + sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, felt.Address{}, pool) parent := makeNewPath(4, 0b1010) l, r := felt.NewFromUint64[felt.Felt](7), felt.NewFromUint64[felt.Felt](13) job := makeSimpleJob(&parent, l, r) diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index 1700eea0f3..e93e49dbbc 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/migration/semaphore" - "github.com/NethermindEth/juno/utils/log" ) type task struct { @@ -17,43 +16,37 @@ type task struct { const dfsStackCap = 251 type ingestor struct { - logger log.StructuredLogger + ctx context.Context database db.KeyValueReader batchSemaphore semaphore.ResourceSemaphore[db.Batch] - tasks []task pool *hashWorkerPool + tasks [IngestorCount]task dfsStacks [IngestorCount][]dfsFrame } -type FlushBatchFn func(db.Batch) db.Batch +type FlushBatchFn func(db.Batch) (db.Batch, error) func newIngestor( ctx context.Context, database db.KeyValueReader, batchSemaphore semaphore.ResourceSemaphore[db.Batch], - logger log.StructuredLogger, pool *hashWorkerPool, -) (*ingestor, error) { - tasks := make([]task, IngestorCount) - for i := range tasks { - tasks[i] = task{batch: batchSemaphore.GetBlocking()} - } - +) *ingestor { in := &ingestor{ + ctx: ctx, database: database, batchSemaphore: batchSemaphore, - logger: logger, - tasks: tasks, pool: pool, } - for i := range in.dfsStacks { + for i := range IngestorCount { + in.tasks[i].batch = batchSemaphore.GetBlocking() in.dfsStacks[i] = make([]dfsFrame, 0, dfsStackCap) } - return in, nil + return in } -func (c *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { - done, err := hasDestRoot(c.database, desc.NewBucket, &desc.Owner) +func (i *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { + done, err := hasDestRoot(i.database, desc.NewBucket, &desc.Owner) if err != nil { return fmt.Errorf("hasDestRoot(%v, %x): %w", desc.NewBucket, desc.Owner, err) } @@ -61,35 +54,43 @@ func (c *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { return nil } - t := &c.tasks[index] + t := &i.tasks[index] - flush := FlushBatchFn(func(current db.Batch) db.Batch { + flush := FlushBatchFn(func(current db.Batch) (db.Batch, error) { if current.Size() < targetBatchByteSize { - return current + return current, nil + } + select { + case <-i.ctx.Done(): + return current, i.ctx.Err() + case outputs <- task{batch: current, tries: t.tries}: } - outputs <- task{batch: current, tries: t.tries} t.tries = 0 - return c.batchSemaphore.GetBlocking() + return i.batchSemaphore.GetBlocking(), nil }) - migrator := newDFSMigrator(desc.NodeCount >= SmallTrieThreshold, c.pool) - t.batch, c.dfsStacks[index], err = migrator.Migrate( - c.database, + migrator := newDFSMigrator(desc.NodeCount >= SmallTrieThreshold, i.pool) + t.batch, i.dfsStacks[index], err = migrator.Migrate( + i.database, t.batch, desc, flush, - c.dfsStacks[index], + i.dfsStacks[index], ) if err != nil { return err } t.tries++ - t.batch = flush(t.batch) - return nil + t.batch, err = flush(t.batch) + return err } -func (c *ingestor) Done(index int, outputs chan<- task) error { - outputs <- c.tasks[index] +func (i *ingestor) Done(index int, outputs chan<- task) error { + select { + case <-i.ctx.Done(): + return i.ctx.Err() + case outputs <- i.tasks[index]: + } return nil } diff --git a/migration/trie/trie.go b/migration/trie/trie.go index e35ece9dba..9b00e55ea8 100644 --- a/migration/trie/trie.go +++ b/migration/trie/trie.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "context" + "errors" "fmt" "time" @@ -33,6 +34,8 @@ var ( shouldNotRerun []byte ) +var deprecatedTrieBuckets = []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} + type Migrator struct{} var _ migration.Migration = (*Migrator)(nil) @@ -58,7 +61,7 @@ func (*Migrator) Migrate( } func needsMigration(r db.KeyValueReader) (bool, error) { - for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { + for _, bucket := range deprecatedTrieBuckets { prefix := bucket.Key() iter, err := r.NewIterator(prefix, true) if err != nil { @@ -85,11 +88,9 @@ func runMigration( }) pool := newHashWorkerPool() + defer pool.close() - ing, err := newIngestor(ctx, database, batchSem, logger, pool) - if err != nil { - return shouldRerun, err - } + ing := newIngestor(ctx, database, batchSem, pool) tries, err := enumerateTries(database) if err != nil { @@ -112,23 +113,33 @@ func runMigration( _, wait := committed.Run(ctx) res := wait() - if !res.IsDone { - pool.close() + if res.Err != nil { return shouldRerun, res.Err } + if !res.IsDone { + if ctxErr := ctx.Err(); ctxErr != nil { + return shouldRerun, ctxErr + } + return shouldRerun, errors.New("trie migration: pipeline did not complete") + } + + if err := wipeDeprecatedBuckets(database); err != nil { + return shouldRerun, err + } + logger.Info("trie migration: source buckets deleted") - pool.close() + return shouldNotRerun, nil +} - for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { +func wipeDeprecatedBuckets(database db.KeyValueRangeDeleter) error { + for _, bucket := range deprecatedTrieBuckets { start := bucket.Key() end := dbutils.UpperBound(start) if err := database.DeleteRange(start, end); err != nil { - return shouldRerun, fmt.Errorf("trie migration: cleanup DeleteRange for %v: %w", bucket, err) + return fmt.Errorf("trie migration: cleanup DeleteRange for %v: %w", bucket, err) } } - logger.Info("trie migration: source buckets deleted") - - return shouldNotRerun, nil + return nil } type TrieDesc struct { @@ -150,42 +161,74 @@ func enumerateTries(r db.KeyValueReader) ([]TrieDesc, error) { {db.ClassesTrie, db.ClassTrie, crypto.Poseidon}, {db.StateTrie, db.ContractTrieContract, crypto.Pedersen}, } { - prefix := spec.oldBucket.Key() - it, err := r.NewIterator(prefix, true) + desc, err := enumerateGlobalTrie(r, spec.oldBucket, spec.newBucket, spec.hashFn) if err != nil { - return nil, fmt.Errorf("opening iterator for bucket %v: %w", spec.oldBucket, err) + return nil, err } - var rootPath *trie.BitArray - count := 0 - for valid := it.First(); valid; valid = it.Next() { - key := it.Key() - if len(key) == len(prefix) { - if val, verr := it.Value(); verr == nil { - rootPath = parseRootPath(val) - } - } else { - count++ + descs = append(descs, desc) + } + + storageDescs, err := enumerateStorageTries(r) + if err != nil { + return nil, err + } + descs = append(descs, storageDescs...) + + return descs, nil +} + +func enumerateGlobalTrie( + r db.KeyValueReader, + oldBucket, newBucket db.Bucket, + hashFn crypto.HashFn, +) (TrieDesc, error) { + prefix := oldBucket.Key() + it, err := r.NewIterator(prefix, true) + if err != nil { + return TrieDesc{}, fmt.Errorf("opening iterator for bucket %v: %w", oldBucket, err) + } + defer it.Close() + + var rootPath *trie.BitArray + count := 0 + for valid := it.First(); valid; valid = it.Next() { + key := it.Key() + if len(key) == len(prefix) { + val, verr := it.Value() + if verr != nil { + return TrieDesc{}, fmt.Errorf("reading root path for bucket %v: %w", oldBucket, verr) } + rp, perr := parseRootPath(val) + if perr != nil { + return TrieDesc{}, fmt.Errorf("parsing root path for bucket %v: %w", oldBucket, perr) + } + rootPath = rp + } else { + count++ } - it.Close() - descs = append(descs, TrieDesc{ - OldBucket: spec.oldBucket, - NewBucket: spec.newBucket, - HashFn: spec.hashFn, - NodeCount: count, - RootPath: rootPath, - }) } + return TrieDesc{ + OldBucket: oldBucket, + NewBucket: newBucket, + HashFn: hashFn, + NodeCount: count, + RootPath: rootPath, + }, nil +} +func enumerateStorageTries(r db.KeyValueReader) ([]TrieDesc, error) { storagePrefix := db.ContractStorage.Key() - storageIter, err := r.NewIterator(storagePrefix, true) + it, err := r.NewIterator(storagePrefix, true) if err != nil { return nil, fmt.Errorf("opening storage iterator: %w", err) } - for valid := storageIter.First(); valid; valid = storageIter.Valid() { - key := storageIter.Key() + defer it.Close() + + var descs []TrieDesc + for valid := it.First(); valid; valid = it.Valid() { + key := it.Key() if len(key) < 1+felt.Bytes { - storageIter.Next() + it.Next() continue } ownerFelt := felt.FromBytes[felt.Felt](key[1 : 1+felt.Bytes]) @@ -195,19 +238,25 @@ func enumerateTries(r db.KeyValueReader) ([]TrieDesc, error) { var rootPath *trie.BitArray count := 0 - for storageIter.Valid() { - k := storageIter.Key() + for it.Valid() { + k := it.Key() if !bytes.HasPrefix(k, ownerPrefix) { break } if len(k) == len(ownerPrefix) { - if val, verr := storageIter.Value(); verr == nil { - rootPath = parseRootPath(val) + val, verr := it.Value() + if verr != nil { + return nil, fmt.Errorf("reading root path for storage owner %s: %w", &ownerFelt, verr) } + rp, perr := parseRootPath(val) + if perr != nil { + return nil, fmt.Errorf("parsing root path for storage owner %s: %w", &ownerFelt, perr) + } + rootPath = rp } else { count++ } - storageIter.Next() + it.Next() } descs = append(descs, TrieDesc{ OldBucket: db.ContractStorage, @@ -218,20 +267,18 @@ func enumerateTries(r db.KeyValueReader) ([]TrieDesc, error) { RootPath: rootPath, }) } - storageIter.Close() - return descs, nil } -func parseRootPath(val []byte) *trie.BitArray { +func parseRootPath(val []byte) (*trie.BitArray, error) { if len(val) == 0 { - return nil + return nil, nil } var ba trie.BitArray if err := ba.UnmarshalBinary(val); err != nil { - return nil + return nil, err } - return &ba + return &ba, nil } func hasDestRoot(r db.KeyValueReader, newBucket db.Bucket, owner *felt.Address) (bool, error) { diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index d3c5322cf9..630bed62fe 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -16,10 +16,9 @@ import ( ) func TestMigrate_FreshDBIsNoOp(t *testing.T) { - testDB := memory.New() - t.Cleanup(func() { testDB.Close() }) + memDB := memory.New() - state, err := (&Migrator{}).Migrate(context.Background(), testDB, nil, nopLogger()) + state, err := (&Migrator{}).Migrate(context.Background(), memDB, nil, nopLogger()) require.NoError(t, err) assert.Nil( t, @@ -30,17 +29,17 @@ func TestMigrate_FreshDBIsNoOp(t *testing.T) { func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { leaves := randomLeaves(100, 7) - testDB, _, _, _, _ := buildFullDB(t, leaves) + memDB := buildFullDB(t, leaves) - needed, err := needsMigration(testDB) + needed, err := needsMigration(memDB) require.NoError(t, err) require.True(t, needed, "precondition: DB has old-format data") - state, err := (&Migrator{}).Migrate(context.Background(), testDB, nil, nopLogger()) + state, err := (&Migrator{}).Migrate(context.Background(), memDB, nil, nopLogger()) require.NoError(t, err) assert.Nil(t, state, "completed migration must return nil intermediate state") - stillNeeded, err := needsMigration(testDB) + stillNeeded, err := needsMigration(memDB) require.NoError(t, err) assert.False(t, stillNeeded, "old-format buckets should be empty after migration") } @@ -49,19 +48,20 @@ func TestMigrationIsResumable(t *testing.T) { leaves := randomLeaves(1000, 42) // Reference: full migration from scratch. - refDB, _, _, _, _ := buildFullDB(t, leaves) + refDB := buildFullDB(t, leaves) _, err := runMigration(context.Background(), refDB, nopLogger()) require.NoError(t, err) // Partial DB: both tries in old format initially. - partialDB, _, _, _, _ := buildFullDB(t, leaves) + partialDB := buildFullDB(t, leaves) // Manually migrate only the class trie to simulate a mid-run interruption. classPrefix := db.ClassesTrie.Key() var classRootPath *trie.BitArray require.NoError(t, partialDB.Get(classPrefix, func(val []byte) error { - classRootPath = parseRootPath(val) - return nil + var perr error + classRootPath, perr = parseRootPath(val) + return perr })) classDesc := TrieDesc{ OldBucket: db.ClassesTrie, @@ -118,29 +118,22 @@ func TestMigrationIsResumable(t *testing.T) { } } -// buildFullDB creates an old-format DB with class, contract, and one storage trie, -// all populated with the same leaf set. Returns the DB and the old-format root hashes. -func buildFullDB(t *testing.T, leaves leafMap) ( - database db.KeyValueStore, - classRoot felt.Felt, - contractRoot felt.Felt, - storageRoot felt.Felt, - owner felt.Address, -) { +// buildFullDB creates an old-format DB populated with a class, a contract, and +// one storage trie, all built from the same leaf set. +func buildFullDB(t *testing.T, leaves leafMap) db.KeyValueStore { t.Helper() - database = memory.New() + database := memory.New() - classRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) - contractRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, db.StateTrie.Key()) + buildDeprecatedTrie(t, database, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) + buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, db.StateTrie.Key()) var ownerFelt felt.Felt ownerFelt.SetUint64(42) - owner = felt.Address(ownerFelt) ownerBytes := ownerFelt.Bytes() storagePrefix := db.ContractStorage.Key(ownerBytes[:]) - storageRoot = buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, storagePrefix) + buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, storagePrefix) - return database, classRoot, contractRoot, storageRoot, owner + return database } func insertFakeStorageNodes(t *testing.T, database db.KeyValueStore, owner felt.Address, n int) { @@ -165,16 +158,16 @@ func collectTries(t *testing.T, r db.KeyValueReader) []TrieDesc { } func TestEnumerateTries_EmptyDBYieldsClassAndContractTries(t *testing.T) { - testDB := memory.New() - descs := collectTries(t, testDB) + memDB := memory.New() + descs := collectTries(t, memDB) require.Len(t, descs, 2) assert.Equal(t, db.ClassesTrie, descs[0].OldBucket) assert.Equal(t, db.StateTrie, descs[1].OldBucket) } func TestEnumerateTries_GlobalTriesPresent(t *testing.T) { - testDB := memory.New() - descs := collectTries(t, testDB) + memDB := memory.New() + descs := collectTries(t, memDB) hasClass := slices.ContainsFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ClassesTrie && d.NewBucket == db.ClassTrie }) @@ -186,16 +179,16 @@ func TestEnumerateTries_GlobalTriesPresent(t *testing.T) { } func TestEnumerateTries_StorageTriesDiscovered(t *testing.T) { - testDB := memory.New() + memDB := memory.New() var owners [3]felt.Address for i := range owners { var f felt.Felt f.SetUint64(uint64(i + 1)) owners[i] = felt.Address(f) - insertFakeStorageNodes(t, testDB, owners[i], 5) + insertFakeStorageNodes(t, memDB, owners[i], 5) } - descs := collectTries(t, testDB) + descs := collectTries(t, memDB) require.Len(t, descs, 5) storageCount := 0 @@ -209,13 +202,13 @@ func TestEnumerateTries_StorageTriesDiscovered(t *testing.T) { } func TestEnumerateTries_NodeCountIsCorrect(t *testing.T) { - testDB := memory.New() + memDB := memory.New() var ownerFelt felt.Felt ownerFelt.SetUint64(99) owner := felt.Address(ownerFelt) - insertFakeStorageNodes(t, testDB, owner, 7) + insertFakeStorageNodes(t, memDB, owner, 7) - descs := collectTries(t, testDB) + descs := collectTries(t, memDB) require.Len(t, descs, 3) idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) @@ -224,14 +217,14 @@ func TestEnumerateTries_NodeCountIsCorrect(t *testing.T) { } func TestEnumerateTries_StorageTrieCountsPresent(t *testing.T) { - testDB := memory.New() + memDB := memory.New() for i, n := range []int{3, 7, 1} { var f felt.Felt f.SetUint64(uint64(i + 1)) - insertFakeStorageNodes(t, testDB, felt.Address(f), n) + insertFakeStorageNodes(t, memDB, felt.Address(f), n) } - descs := collectTries(t, testDB) + descs := collectTries(t, memDB) var storageCounts []int for _, d := range descs { if d.OldBucket == db.ContractStorage { @@ -243,13 +236,13 @@ func TestEnumerateTries_StorageTrieCountsPresent(t *testing.T) { } func TestEnumerateTries_StorageTrieOwnerMatchesKey(t *testing.T) { - testDB := memory.New() + memDB := memory.New() var ownerFelt felt.Felt ownerFelt.SetUint64(12345) owner := felt.Address(ownerFelt) - insertFakeStorageNodes(t, testDB, owner, 3) + insertFakeStorageNodes(t, memDB, owner, 3) - descs := collectTries(t, testDB) + descs := collectTries(t, memDB) require.Len(t, descs, 3) idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) @@ -258,16 +251,16 @@ func TestEnumerateTries_StorageTrieOwnerMatchesKey(t *testing.T) { } func TestEnumerateTries_MultipleOwnersOrdered(t *testing.T) { - testDB := memory.New() + memDB := memory.New() owners := make([]felt.Address, 5) for i := range owners { var f felt.Felt f.SetUint64(uint64(i + 1)) owners[i] = felt.Address(f) - insertFakeStorageNodes(t, testDB, owners[i], i+1) + insertFakeStorageNodes(t, memDB, owners[i], i+1) } - descs := collectTries(t, testDB) + descs := collectTries(t, memDB) require.Len(t, descs, 7) var storageCounts []int From bfe3da24cda45f2ec2b6412ccb4e261275450e7d Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 11:50:24 +0200 Subject: [PATCH 09/14] chore: reorganise the files --- migration/trie/dfsmigration.go | 253 ------------------------ migration/trie/dfsmigration_test.go | 262 ------------------------- migration/trie/ingestor.go | 293 +++++++++++++++++++++++++--- migration/trie/trie_test.go | 255 +++++++++++++++++++++++- 4 files changed, 518 insertions(+), 545 deletions(-) delete mode 100644 migration/trie/dfsmigration.go delete mode 100644 migration/trie/dfsmigration_test.go diff --git a/migration/trie/dfsmigration.go b/migration/trie/dfsmigration.go deleted file mode 100644 index 4ea77dc3be..0000000000 --- a/migration/trie/dfsmigration.go +++ /dev/null @@ -1,253 +0,0 @@ -package trie - -import ( - "fmt" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/core/trie2/trieutils" - "github.com/NethermindEth/juno/db" -) - -type dfsMigrator struct { - parallelDispatch bool - pool *hashWorkerPool -} - -func newDFSMigrator(parallelDispatch bool, pool *hashWorkerPool) *dfsMigrator { - return &dfsMigrator{parallelDispatch: parallelDispatch, pool: pool} -} - -func (m *dfsMigrator) Migrate( - r db.KeyValueReader, - batch db.Batch, - desc TrieDesc, - flush FlushBatchFn, - stack []dfsFrame, -) (db.Batch, []dfsFrame, error) { - stack = stack[:0] - if desc.RootPath == nil { - return batch, stack, nil - } - prefix := oldTriePrefix(desc) - rootPath := desc.RootPath - sched := newHashScheduler(desc.HashFn, - m.parallelDispatch, - desc.NewBucket, desc.Owner, m.pool) - - rootHash, batch, stack, err := traverse(r, prefix, rootPath, sched, batch, flush, stack) - if err != nil { - return batch, stack, err - } - if err = sched.sync(batch); err != nil { - return batch, stack, err - } - if rootPath.Len() > 0 { - if err = writeRootEdge(rootPath, rootHash, sched, batch); err != nil { - return batch, stack, err - } - } - return batch, stack, nil -} - -type traverseStackState uint8 - -const ( - readNodeState traverseStackState = iota - leftSubtreeDoneState - rightSubtreeDoneState -) - -type dfsFrame struct { - oldPath trie.BitArray - left trie.BitArray - right trie.BitArray - value felt.Felt - leftHash felt.Felt - isLeaf bool - state traverseStackState -} - -type parsedNode struct { - value felt.Felt - left trie.BitArray - right trie.BitArray - isLeaf bool -} - -func traverse( - r db.KeyValueReader, - prefix []byte, - start *trie.BitArray, - sched *hashScheduler, - batch db.Batch, - flush FlushBatchFn, - stack []dfsFrame, -) (felt.Felt, db.Batch, []dfsFrame, error) { - stack = stack[:1] - stack[0] = dfsFrame{oldPath: *start} - var lastHash felt.Felt - for len(stack) > 0 { - top := &stack[len(stack)-1] - switch top.state { - case readNodeState: - parsed, err := readNode(r, prefix, &top.oldPath) - if err != nil { - return felt.Felt{}, batch, stack, err - } - top.value = parsed.value - top.left = parsed.left - top.right = parsed.right - top.isLeaf = parsed.isLeaf - if top.isLeaf { - newPath := toNewPath(&top.oldPath) - if err := processLeaf(newPath, &top.value, sched, batch); err != nil { - return felt.Felt{}, batch, stack, err - } - lastHash = top.value - stack = stack[:len(stack)-1] - } else { - left := top.left - top.state = leftSubtreeDoneState - stack = pushFrame(stack, left) - } - case leftSubtreeDoneState: - top.leftHash = lastHash - right := top.right - top.state = rightSubtreeDoneState - stack = pushFrame(stack, right) - case rightSubtreeDoneState: - newPath := toNewPath(&top.oldPath) - if err := processBinary( - newPath, - &top.left, - &top.right, - top.leftHash, - lastHash, - sched, - batch, - ); err != nil { - return felt.Felt{}, batch, stack, err - } - lastHash = top.value - stack = stack[:len(stack)-1] - } - var err error - batch, err = flush(batch) - if err != nil { - return felt.Felt{}, batch, stack, err - } - } - return lastHash, batch, stack, nil -} - -func pushFrame(stack []dfsFrame, oldPath trie.BitArray) []dfsFrame { - n := len(stack) - stack = stack[:n+1] - stack[n] = dfsFrame{oldPath: oldPath} - return stack -} - -const maxOldKeySize = 1 + 32 + 1 + 32 - -func encodeOldPath(path *trie.BitArray, dst []byte) int { - pathLen := path.Len() - b := path.Bytes() - activeBytes := (uint(pathLen) + 7) / 8 - dst[0] = pathLen - copy(dst[1:], b[32-activeBytes:]) - return int(activeBytes) + 1 -} - -// readNode loads the deprecated-format node at (prefix, oldPath) and returns -// its parsed fields. The caller assigns the parsed values into its own state -// (e.g. a DFS stack frame); this function does not mutate any input. -func readNode(r db.KeyValueReader, prefix []byte, oldPath *trie.BitArray) (parsedNode, error) { - var arr [maxOldKeySize]byte - n := copy(arr[:], prefix) - n += encodeOldPath(oldPath, arr[n:]) - var node parsedNode - err := r.Get(arr[:n], func(val []byte) error { - var perr error - node, perr = parseNodeData(val) - return perr - }) - return node, err -} - -// parseNodeData decodes a deprecated-format node's raw bytes: -// felt(value) [ BitArray(left) BitArray(right) [ felt felt ] ] -// The trailing left/right hashes are ignored — the migrator re-derives hashes -// itself — so only the fields it actually needs are returned. -func parseNodeData(data []byte) (parsedNode, error) { - var n parsedNode - if len(data) < felt.Bytes { - return n, fmt.Errorf("trie: node data too short (%d bytes)", len(data)) - } - n.value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) - data = data[felt.Bytes:] - if len(data) == 0 { - n.isLeaf = true - return n, nil - } - if err := n.left.UnmarshalBinary(data); err != nil { - return n, fmt.Errorf("trie: unmarshalling left path: %w", err) - } - data = data[n.left.EncodedLen():] - if err := n.right.UnmarshalBinary(data); err != nil { - return n, fmt.Errorf("trie: unmarshalling right path: %w", err) - } - return n, nil -} - -func processLeaf( - path trieutils.Path, - value *felt.Felt, - sched *hashScheduler, - batch db.Batch, -) error { - var buf [trieutils.MaxNodeKeySize + valueNodeBlobSize]byte - keyLen := trieutils.EncodeNodeKey(buf[:], sched.bucket, &sched.owner, &path, true) - blob := encodeValueNode(value) - copy(buf[keyLen:], blob[:]) - return batch.Put(buf[:keyLen], buf[keyLen:keyLen+valueNodeBlobSize]) -} - -func processBinary( - parentPath trieutils.Path, - left, right *trie.BitArray, - leftChildHash, rightChildHash felt.Felt, - sched *hashScheduler, - batch db.Batch, -) error { - leftSeg := compressedSegment(left, parentPath.Len()) - rightSeg := compressedSegment(right, parentPath.Len()) - return sched.schedule(&edgeHashJob{ - leftChildHash: leftChildHash, - leftSeg: leftSeg, - rightChildHash: rightChildHash, - rightSeg: rightSeg, - parentPath: parentPath, - }, batch) -} - -func writeRootEdge( - rootPath *trie.BitArray, - childHash felt.Felt, - sched *hashScheduler, - batch db.Batch, -) error { - seg := toNewPath(rootPath) - var buf [edgeNodeMaxSize]byte - n := encodeEdgeNodeInto(buf[:], &childHash, &seg) - return trieutils.WriteNodeByPath(batch, sched.bucket, &sched.owner, &trieutils.Path{}, false, buf[:n]) -} - -func oldTriePrefix(desc TrieDesc) []byte { - if desc.OldBucket == db.ContractStorage { - ownerFelt := felt.Felt(desc.Owner) - ownerBytes := ownerFelt.Bytes() - return desc.OldBucket.Key(ownerBytes[:]) - } - return desc.OldBucket.Key() -} diff --git a/migration/trie/dfsmigration_test.go b/migration/trie/dfsmigration_test.go deleted file mode 100644 index 79e220de27..0000000000 --- a/migration/trie/dfsmigration_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package trie - -import ( - "context" - "math/rand" - "testing" - - "github.com/NethermindEth/juno/core/crypto" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/core/trie2" - "github.com/NethermindEth/juno/core/trie2/triedb/rawdb" - "github.com/NethermindEth/juno/core/trie2/trienode" - "github.com/NethermindEth/juno/core/trie2/trieutils" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" - "github.com/NethermindEth/juno/utils/log" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type leafMap map[felt.Felt]felt.Felt - -func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } -func noFlush(current db.Batch) (db.Batch, error) { return current, nil } - -type trieCase struct { - name string - oldBucket db.Bucket - newBucket db.Bucket - owner felt.Address - oldBuildPrefix func(owner felt.Address) []byte - newTrieID func(owner felt.Address) trieutils.TrieID - hashFn crypto.HashFn - //nolint:staticcheck // Necessary for old state - buildOldFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error) -} - -var trieCases = []trieCase{ - { - name: "ClassTrie", - oldBucket: db.ClassesTrie, - newBucket: db.ClassTrie, - oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.ClassesTrie)} }, - newTrieID: func(_ felt.Address) trieutils.TrieID { - return trieutils.NewClassTrieID(felt.StateRootHash(felt.One)) - }, - hashFn: crypto.Poseidon, - buildOldFn: trie.NewTriePoseidon, - }, - { - name: "ContractTrie", - oldBucket: db.StateTrie, - newBucket: db.ContractTrieContract, - oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.StateTrie)} }, - newTrieID: func(_ felt.Address) trieutils.TrieID { - return trieutils.NewContractTrieID(felt.StateRootHash(felt.One)) - }, - hashFn: crypto.Pedersen, - buildOldFn: trie.NewTriePedersen, - }, - { - name: "StorageTrie", - oldBucket: db.ContractStorage, - newBucket: db.ContractTrieStorage, - owner: felt.FromUint64[felt.Address](42), - oldBuildPrefix: func(owner felt.Address) []byte { - ownerFelt := felt.Felt(owner) - ownerBytes := ownerFelt.Bytes() - return db.ContractStorage.Key(ownerBytes[:]) - }, - newTrieID: func(owner felt.Address) trieutils.TrieID { - return trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), owner) - }, - hashFn: crypto.Pedersen, - buildOldFn: trie.NewTriePedersen, - }, -} - -// randomLeaves generates n distinct leaf key-value pairs using a fixed seed, -// with keys spread across the full 251-bit felt range for structural variety. -func randomLeaves(n int, seed int64) leafMap { - rng := rand.New(rand.NewSource(seed)) - leaves := make(leafMap, n) - var kb, vb [32]byte - for len(leaves) < n { - rng.Read(kb[:]) - rng.Read(vb[:]) - // Clear the top 5 bits so all keys are safely below the StarkNet prime (~2^251+δ). - kb[0] &= 0x07 - k := felt.FromBytes[felt.Felt](kb[:]) - v := felt.FromBytes[felt.Felt](vb[:]) - leaves[k] = v - } - return leaves -} - -var transcoderCases = []struct { - name string - leaves leafMap -}{ - // --- trivial structural cases --- - {"empty trie", nil}, - {"single leaf", leafMap{ - felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](100), - }}, - - // --- two-leaf cases probing specific split depths --- - - // Keys 2 and 3 differ only in the last path bit (bit 250 of felt = path bit 250). - // Root edge spans 250 bits; binary node is at maximum depth. - {"deep split", leafMap{ - felt.FromUint64[felt.Felt](2): felt.FromUint64[felt.Felt](10), - felt.FromUint64[felt.Felt](3): felt.FromUint64[felt.Felt](20), - }}, - - // One key < 2^250 (trie path bit[0]=0), one key ≥ 2^250 (trie path bit[0]=1). - // Root is a binary node — rootPath.Len()==0, no root edge written. - // This path is never hit by sequential small-integer keys. - {"left right split", leafMap{ - felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), - felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](20), // 2^250 - }}, - - // --- four leaves covering all 2-bit path prefixes (00, 01, 10, 11) --- - // Root is a binary node; left and right subtrees each contain a binary node. - // Tests two levels of binary processing with rootPath.Len()==0. - {"full depth 2 tree", leafMap{ - felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), // 00... - felt.FromBytes[felt.Felt]([]byte{0x02}): felt.FromUint64[felt.Felt](20), // 01... (2^249) - felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](30), // 10... (2^250) - felt.FromBytes[felt.Felt]([]byte{0x06}): felt.FromUint64[felt.Felt](40), // 11... (2^250+2^249) - }}, - - // --- larger sequential case for basic load --- - {"hundred sequential leaves", func() leafMap { - leaves := make(leafMap, 100) - for i := 1; i <= 100; i++ { - leaves[felt.FromUint64[felt.Felt](uint64(i))] = felt.FromUint64[felt.Felt](uint64(i) * 7) - } - return leaves - }()}, - - // --- random leaves spanning the full 251-bit space --- - // Keys are evenly distributed across all trie depths, exercising every code - // path: balanced binary nodes, varied edge lengths, and batch flush thresholds. - {"random 1000 leaves", randomLeaves(1000, 42)}, -} - -// TestMigrateTrieMatchesNativeTrie2 verifies that each backend produces byte-for-byte -// identical DB output to a natively-built trie2 for all three trie types and all leaf -// counts. This catches encoding bugs that root-hash comparison cannot detect. -func TestMigrationEndToEnd(t *testing.T) { - type testCase struct { - name string - tc trieCase - leaves leafMap - } - - var cases []testCase - for _, tc := range trieCases { - for _, lc := range transcoderCases { - cases = append(cases, testCase{ - name: tc.name + "/" + lc.name, - tc: tc, - leaves: lc.leaves, - }) - } - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - prefix := c.tc.oldBuildPrefix(c.tc.owner) - - migratedDB := memory.New() - buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) - _, err := runMigration(context.Background(), migratedDB, nopLogger()) - require.NoError(t, err) - - nativeDB := memory.New() - buildTrie(t, nativeDB, c.leaves, - c.tc.newTrieID(c.tc.owner), c.tc.hashFn, c.tc.newBucket) - - assert.Equal(t, - allKeysUnder(t, nativeDB, c.tc.newBucket), - allKeysUnder(t, migratedDB, c.tc.newBucket)) - }) - } -} - -func buildDeprecatedTrie( - t *testing.T, - database db.KeyValueStore, - leaves leafMap, - //nolint:staticcheck // Necessary for old state - trieFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error), - prefix []byte, -) felt.Felt { - t.Helper() - //nolint:staticcheck // Necessary for old state - txn := database.NewIndexedBatch() - tr, err := trieFn(txn, prefix, 251) - require.NoError(t, err) - for key, value := range leaves { - _, err := tr.Put(&key, &value) - require.NoError(t, err) - } - root, err := tr.Root() - require.NoError(t, err) - require.NoError(t, tr.Commit()) - require.NoError(t, txn.Write()) - return root -} - -// buildNativeTrie2 builds a trie2 natively from leaves and persists it to kvStore. -// newBucket distinguishes class trie (db.ClassTrie) from contract/storage tries — -// it controls which Update argument the NodeSet is passed as. -func buildTrie( - t *testing.T, - kvStore db.KeyValueStore, - leaves leafMap, - id trieutils.TrieID, - hashFn crypto.HashFn, - newBucket db.Bucket, -) { - t.Helper() - rawDB := rawdb.New(kvStore) - tr, err := trie2.New(id, 251, hashFn, rawDB) - require.NoError(t, err) - for key, value := range leaves { - require.NoError(t, tr.Update(&key, &value)) - } - root, nodes := tr.Commit() - if nodes == nil { - return // empty trie — nothing to persist - } - mergeSet := trienode.NewMergeNodeSet(nodes) - var zero felt.StateRootHash - stateRoot := felt.StateRootHash(root) - batch := kvStore.NewBatch() - if newBucket == db.ClassTrie { - require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, mergeSet, nil, batch)) - } else { - require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, nil, mergeSet, batch)) - } - require.NoError(t, batch.Write()) -} - -func allKeysUnder(t *testing.T, r db.KeyValueReader, bucket db.Bucket) map[string][]byte { - t.Helper() - prefix := bucket.Key() - iter, err := r.NewIterator(prefix, true) - require.NoError(t, err) - defer iter.Close() - out := make(map[string][]byte) - for ok := iter.First(); ok; ok = iter.Next() { - val, err := iter.Value() - require.NoError(t, err) - out[string(iter.Key())] = val - } - return out -} diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index e93e49dbbc..328d07cb33 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -4,17 +4,31 @@ import ( "context" "fmt" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/migration/semaphore" ) +const ( + dfsStackCap = 251 + maxOldKeySize = 1 + 32 + 1 + 32 +) + +type traverseStackState uint8 + +const ( + readNodeState traverseStackState = iota + leftSubtreeDoneState + rightSubtreeDoneState +) + type task struct { batch db.Batch tries int } -const dfsStackCap = 251 - type ingestor struct { ctx context.Context database db.KeyValueReader @@ -24,8 +38,6 @@ type ingestor struct { dfsStacks [IngestorCount][]dfsFrame } -type FlushBatchFn func(db.Batch) (db.Batch, error) - func newIngestor( ctx context.Context, database db.KeyValueReader, @@ -56,34 +68,15 @@ func (i *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { t := &i.tasks[index] - flush := FlushBatchFn(func(current db.Batch) (db.Batch, error) { - if current.Size() < targetBatchByteSize { - return current, nil - } - select { - case <-i.ctx.Done(): - return current, i.ctx.Err() - case outputs <- task{batch: current, tries: t.tries}: - } - t.tries = 0 - return i.batchSemaphore.GetBlocking(), nil - }) - - migrator := newDFSMigrator(desc.NodeCount >= SmallTrieThreshold, i.pool) - t.batch, i.dfsStacks[index], err = migrator.Migrate( - i.database, - t.batch, - desc, - flush, - i.dfsStacks[index], + i.dfsStacks[index], err = migrateTrie( + i.database, desc, i.pool, t, i.flush, outputs, i.dfsStacks[index], ) if err != nil { return err } t.tries++ - t.batch, err = flush(t.batch) - return err + return i.flush(t, outputs) } func (i *ingestor) Done(index int, outputs chan<- task) error { @@ -94,3 +87,251 @@ func (i *ingestor) Done(index int, outputs chan<- task) error { } return nil } + +// flush rotates t.batch if it's hit target size: sends the current batch +// downstream and acquires a fresh one. Mutates t in place. +func (i *ingestor) flush(t *task, outputs chan<- task) error { + if t.batch.Size() < targetBatchByteSize { + return nil + } + select { + case <-i.ctx.Done(): + return i.ctx.Err() + case outputs <- task{batch: t.batch, tries: t.tries}: + } + t.tries = 0 + t.batch = i.batchSemaphore.GetBlocking() + return nil +} + +// migrateTrie writes the new-format representation of a single deprecated +// trie into t.batch. It walks the deprecated trie in DFS order, emitting +// value/binary/edge nodes via the hashScheduler, and calls flush once per +// step to rotate the batch when it hits target size. +func migrateTrie( + r db.KeyValueReader, + desc TrieDesc, + pool *hashWorkerPool, + t *task, + flush func(*task, chan<- task) error, + outputs chan<- task, + stack []dfsFrame, +) ([]dfsFrame, error) { + stack = stack[:0] + if desc.RootPath == nil { + return stack, nil + } + parallelDispatch := desc.NodeCount >= SmallTrieThreshold + prefix := oldTriePrefix(desc) + sched := newHashScheduler(desc.HashFn, parallelDispatch, desc.NewBucket, desc.Owner, pool) + + rootHash, stack, err := traverse(r, prefix, desc.RootPath, sched, t, flush, outputs, stack) + if err != nil { + return stack, err + } + if err := sched.sync(t.batch); err != nil { + return stack, err + } + if desc.RootPath.Len() > 0 { + if err := writeRootEdge(desc.RootPath, rootHash, sched, t.batch); err != nil { + return stack, err + } + } + return stack, nil +} + +type dfsFrame struct { + oldPath trie.BitArray + left trie.BitArray + right trie.BitArray + value felt.Felt + leftHash felt.Felt + isLeaf bool + state traverseStackState +} + +type parsedNode struct { + value felt.Felt + left trie.BitArray + right trie.BitArray + isLeaf bool +} + +func traverse( + r db.KeyValueReader, + prefix []byte, + start *trie.BitArray, + sched *hashScheduler, + t *task, + flush func(*task, chan<- task) error, + outputs chan<- task, + stack []dfsFrame, +) (felt.Felt, []dfsFrame, error) { + stack = stack[:1] + stack[0] = dfsFrame{oldPath: *start} + var lastHash felt.Felt + for len(stack) > 0 { + top := &stack[len(stack)-1] + switch top.state { + case readNodeState: + parsed, err := readNode(r, prefix, &top.oldPath) + if err != nil { + return felt.Felt{}, stack, err + } + top.value = parsed.value + top.left = parsed.left + top.right = parsed.right + top.isLeaf = parsed.isLeaf + if top.isLeaf { + newPath := toNewPath(&top.oldPath) + if err := processLeaf(newPath, &top.value, sched, t.batch); err != nil { + return felt.Felt{}, stack, err + } + lastHash = top.value + stack = stack[:len(stack)-1] + } else { + top.state = leftSubtreeDoneState + stack = pushFrame(stack, top.left) + } + case leftSubtreeDoneState: + top.leftHash = lastHash + top.state = rightSubtreeDoneState + stack = pushFrame(stack, top.right) + case rightSubtreeDoneState: + newPath := toNewPath(&top.oldPath) + if err := processBinary( + newPath, + &top.left, + &top.right, + top.leftHash, + lastHash, + sched, + t.batch, + ); err != nil { + return felt.Felt{}, stack, err + } + lastHash = top.value + stack = stack[:len(stack)-1] + } + if err := flush(t, outputs); err != nil { + return felt.Felt{}, stack, err + } + } + return lastHash, stack, nil +} + +func pushFrame(stack []dfsFrame, oldPath trie.BitArray) []dfsFrame { + n := len(stack) + stack = stack[:n+1] + stack[n] = dfsFrame{oldPath: oldPath} + return stack +} + +func encodeOldPath(path *trie.BitArray, dst []byte) int { + pathLen := path.Len() + b := path.Bytes() + activeBytes := (uint(pathLen) + 7) / 8 + dst[0] = pathLen + copy(dst[1:], b[32-activeBytes:]) + return int(activeBytes) + 1 +} + +// readNode loads the deprecated-format node at (prefix, oldPath) and returns +// its parsed fields. The caller assigns the parsed values into its own state +// (e.g. a DFS stack frame); this function does not mutate any input. +func readNode(r db.KeyValueReader, prefix []byte, oldPath *trie.BitArray) (parsedNode, error) { + var arr [maxOldKeySize]byte + n := copy(arr[:], prefix) + n += encodeOldPath(oldPath, arr[n:]) + var node parsedNode + err := r.Get(arr[:n], func(val []byte) error { + var perr error + node, perr = parseNodeData(val) + return perr + }) + return node, err +} + +// parseNodeData decodes a deprecated-format node's raw bytes: +// felt(value) [ BitArray(left) BitArray(right) [ felt felt ] ] +// The trailing left/right hashes are ignored — the migrator re-derives hashes +// itself — so only the fields it actually needs are returned. +func parseNodeData(data []byte) (parsedNode, error) { + var n parsedNode + if len(data) < felt.Bytes { + return n, fmt.Errorf("trie: node data too short (%d bytes)", len(data)) + } + n.value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) + data = data[felt.Bytes:] + if len(data) == 0 { + n.isLeaf = true + return n, nil + } + if err := n.left.UnmarshalBinary(data); err != nil { + return n, fmt.Errorf("trie: unmarshalling left path: %w", err) + } + data = data[n.left.EncodedLen():] + if err := n.right.UnmarshalBinary(data); err != nil { + return n, fmt.Errorf("trie: unmarshalling right path: %w", err) + } + return n, nil +} + +func processLeaf( + path trieutils.Path, + value *felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + var buf [trieutils.MaxNodeKeySize + valueNodeBlobSize]byte + keyLen := trieutils.EncodeNodeKey(buf[:], sched.bucket, &sched.owner, &path, true) + blob := encodeValueNode(value) + copy(buf[keyLen:], blob[:]) + return batch.Put(buf[:keyLen], buf[keyLen:keyLen+valueNodeBlobSize]) +} + +func processBinary( + parentPath trieutils.Path, + left, right *trie.BitArray, + leftChildHash, rightChildHash felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + leftSeg := compressedSegment(left, parentPath.Len()) + rightSeg := compressedSegment(right, parentPath.Len()) + return sched.schedule(&edgeHashJob{ + leftChildHash: leftChildHash, + leftSeg: leftSeg, + rightChildHash: rightChildHash, + rightSeg: rightSeg, + parentPath: parentPath, + }, batch) +} + +func writeRootEdge( + rootPath *trie.BitArray, + childHash felt.Felt, + sched *hashScheduler, + batch db.Batch, +) error { + seg := toNewPath(rootPath) + var buf [edgeNodeMaxSize]byte + n := encodeEdgeNodeInto(buf[:], &childHash, &seg) + return trieutils.WriteNodeByPath( + batch, + sched.bucket, + &sched.owner, + &trieutils.Path{}, + false, + buf[:n], + ) +} + +func oldTriePrefix(desc TrieDesc) []byte { + if desc.OldBucket == db.ContractStorage { + ownerFelt := felt.Felt(desc.Owner) + ownerBytes := ownerFelt.Bytes() + return desc.OldBucket.Key(ownerBytes[:]) + } + return desc.OldBucket.Key() +} diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index 630bed62fe..840c55aa2d 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -3,18 +3,152 @@ package trie import ( "bytes" "context" + "math/rand" "slices" "testing" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/triedb/rawdb" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type leafMap map[felt.Felt]felt.Felt + +func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } +func nopFlush(*task, chan<- task) error { return nil } + +type trieCase struct { + name string + oldBucket db.Bucket + newBucket db.Bucket + owner felt.Address + oldBuildPrefix func(owner felt.Address) []byte + newTrieID func(owner felt.Address) trieutils.TrieID + hashFn crypto.HashFn + //nolint:staticcheck // Necessary for old state + buildOldFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error) +} + +var trieCases = []trieCase{ + { + name: "ClassTrie", + oldBucket: db.ClassesTrie, + newBucket: db.ClassTrie, + oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.ClassesTrie)} }, + newTrieID: func(_ felt.Address) trieutils.TrieID { + return trieutils.NewClassTrieID(felt.StateRootHash(felt.One)) + }, + hashFn: crypto.Poseidon, + buildOldFn: trie.NewTriePoseidon, + }, + { + name: "ContractTrie", + oldBucket: db.StateTrie, + newBucket: db.ContractTrieContract, + oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.StateTrie)} }, + newTrieID: func(_ felt.Address) trieutils.TrieID { + return trieutils.NewContractTrieID(felt.StateRootHash(felt.One)) + }, + hashFn: crypto.Pedersen, + buildOldFn: trie.NewTriePedersen, + }, + { + name: "StorageTrie", + oldBucket: db.ContractStorage, + newBucket: db.ContractTrieStorage, + owner: felt.FromUint64[felt.Address](42), + oldBuildPrefix: func(owner felt.Address) []byte { + ownerFelt := felt.Felt(owner) + ownerBytes := ownerFelt.Bytes() + return db.ContractStorage.Key(ownerBytes[:]) + }, + newTrieID: func(owner felt.Address) trieutils.TrieID { + return trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), owner) + }, + hashFn: crypto.Pedersen, + buildOldFn: trie.NewTriePedersen, + }, +} + +// randomLeaves generates n distinct leaf key-value pairs using a fixed seed, +// with keys spread across the full 251-bit felt range for structural variety. +func randomLeaves(n int, seed int64) leafMap { + rng := rand.New(rand.NewSource(seed)) + leaves := make(leafMap, n) + var kb, vb [32]byte + for len(leaves) < n { + rng.Read(kb[:]) + rng.Read(vb[:]) + // Clear the top 5 bits so all keys are safely below the StarkNet prime (~2^251+δ). + kb[0] &= 0x07 + k := felt.FromBytes[felt.Felt](kb[:]) + v := felt.FromBytes[felt.Felt](vb[:]) + leaves[k] = v + } + return leaves +} + +var transcoderCases = []struct { + name string + leaves leafMap +}{ + // --- trivial structural cases --- + {"empty trie", nil}, + {"single leaf", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](100), + }}, + + // --- two-leaf cases probing specific split depths --- + + // Keys 2 and 3 differ only in the last path bit (bit 250 of felt = path bit 250). + // Root edge spans 250 bits; binary node is at maximum depth. + {"deep split", leafMap{ + felt.FromUint64[felt.Felt](2): felt.FromUint64[felt.Felt](10), + felt.FromUint64[felt.Felt](3): felt.FromUint64[felt.Felt](20), + }}, + + // One key < 2^250 (trie path bit[0]=0), one key ≥ 2^250 (trie path bit[0]=1). + // Root is a binary node — rootPath.Len()==0, no root edge written. + // This path is never hit by sequential small-integer keys. + {"left right split", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), + felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](20), // 2^250 + }}, + + // --- four leaves covering all 2-bit path prefixes (00, 01, 10, 11) --- + // Root is a binary node; left and right subtrees each contain a binary node. + // Tests two levels of binary processing with rootPath.Len()==0. + {"full depth 2 tree", leafMap{ + felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), // 00... + felt.FromBytes[felt.Felt]([]byte{0x02}): felt.FromUint64[felt.Felt](20), // 01... (2^249) + felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](30), // 10... (2^250) + felt.FromBytes[felt.Felt]([]byte{0x06}): felt.FromUint64[felt.Felt](40), // 11... (2^250+2^249) + }}, + + // --- larger sequential case for basic load --- + {"hundred sequential leaves", func() leafMap { + leaves := make(leafMap, 100) + for i := 1; i <= 100; i++ { + leaves[felt.FromUint64[felt.Felt](uint64(i))] = felt.FromUint64[felt.Felt](uint64(i) * 7) + } + return leaves + }()}, + + // --- random leaves spanning the full 251-bit space --- + // Keys are evenly distributed across all trie depths, exercising every code + // path: balanced binary nodes, varied edge lengths, and batch flush thresholds. + {"random 1000 leaves", randomLeaves(1000, 42)}, +} + func TestMigrate_FreshDBIsNoOp(t *testing.T) { memDB := memory.New() @@ -44,6 +178,47 @@ func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { assert.False(t, stillNeeded, "old-format buckets should be empty after migration") } +// TestMigrationEndToEnd verifies that the migration produces byte-for-byte +// identical DB output to a natively-built trie2 for all three trie types and +// all leaf counts. Catches encoding bugs that root-hash comparison cannot. +func TestMigrationEndToEnd(t *testing.T) { + type testCase struct { + name string + tc trieCase + leaves leafMap + } + + var cases []testCase + for _, tc := range trieCases { + for _, lc := range transcoderCases { + cases = append(cases, testCase{ + name: tc.name + "/" + lc.name, + tc: tc, + leaves: lc.leaves, + }) + } + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + prefix := c.tc.oldBuildPrefix(c.tc.owner) + + migratedDB := memory.New() + buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) + _, err := runMigration(context.Background(), migratedDB, nopLogger()) + require.NoError(t, err) + + nativeDB := memory.New() + buildTrie(t, nativeDB, c.leaves, + c.tc.newTrieID(c.tc.owner), c.tc.hashFn, c.tc.newBucket) + + assert.Equal(t, + allKeysUnder(t, nativeDB, c.tc.newBucket), + allKeysUnder(t, migratedDB, c.tc.newBucket)) + }) + } +} + func TestMigrationIsResumable(t *testing.T) { leaves := randomLeaves(1000, 42) @@ -70,15 +245,14 @@ func TestMigrationIsResumable(t *testing.T) { NodeCount: len(leaves), RootPath: classRootPath, } - batch := partialDB.NewBatch() pool := newHashWorkerPool() defer pool.close() - migrator := newDFSMigrator(true, pool) + t1 := &task{batch: partialDB.NewBatch()} stack := make([]dfsFrame, 0, dfsStackCap) - _, _, err = migrator.Migrate(partialDB, batch, classDesc, noFlush, stack) + _, err = migrateTrie(partialDB, classDesc, pool, t1, nopFlush, nil, stack) require.NoError(t, err) - require.NoError(t, batch.Write()) + require.NoError(t, t1.batch.Write()) done, err := hasDestRoot(partialDB, db.ClassTrie, &classDesc.Owner) require.NoError(t, err) @@ -136,6 +310,79 @@ func buildFullDB(t *testing.T, leaves leafMap) db.KeyValueStore { return database } +func buildDeprecatedTrie( + t *testing.T, + database db.KeyValueStore, + leaves leafMap, + //nolint:staticcheck // Necessary for old state + trieFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error), + prefix []byte, +) felt.Felt { + t.Helper() + //nolint:staticcheck // Necessary for old state + txn := database.NewIndexedBatch() + tr, err := trieFn(txn, prefix, 251) + require.NoError(t, err) + for key, value := range leaves { + _, err := tr.Put(&key, &value) + require.NoError(t, err) + } + root, err := tr.Root() + require.NoError(t, err) + require.NoError(t, tr.Commit()) + require.NoError(t, txn.Write()) + return root +} + +// buildTrie builds a trie2 natively from leaves and persists it to kvStore. +// newBucket distinguishes class trie (db.ClassTrie) from contract/storage tries — +// it controls which Update argument the NodeSet is passed as. +func buildTrie( + t *testing.T, + kvStore db.KeyValueStore, + leaves leafMap, + id trieutils.TrieID, + hashFn crypto.HashFn, + newBucket db.Bucket, +) { + t.Helper() + rawDB := rawdb.New(kvStore) + tr, err := trie2.New(id, 251, hashFn, rawDB) + require.NoError(t, err) + for key, value := range leaves { + require.NoError(t, tr.Update(&key, &value)) + } + root, nodes := tr.Commit() + if nodes == nil { + return // empty trie — nothing to persist + } + mergeSet := trienode.NewMergeNodeSet(nodes) + var zero felt.StateRootHash + stateRoot := felt.StateRootHash(root) + batch := kvStore.NewBatch() + if newBucket == db.ClassTrie { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, mergeSet, nil, batch)) + } else { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, nil, mergeSet, batch)) + } + require.NoError(t, batch.Write()) +} + +func allKeysUnder(t *testing.T, r db.KeyValueReader, bucket db.Bucket) map[string][]byte { + t.Helper() + prefix := bucket.Key() + iter, err := r.NewIterator(prefix, true) + require.NoError(t, err) + defer iter.Close() + out := make(map[string][]byte) + for ok := iter.First(); ok; ok = iter.Next() { + val, err := iter.Value() + require.NoError(t, err) + out[string(iter.Key())] = val + } + return out +} + func insertFakeStorageNodes(t *testing.T, database db.KeyValueStore, owner felt.Address, n int) { t.Helper() ownerFelt := felt.Felt(owner) From ca2aadb9883c37cec1b88a2a2fa6fe261f56f57c Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 13:01:24 +0200 Subject: [PATCH 10/14] chore: self review, cleanups --- migration/trie/codec_test.go | 8 +- migration/trie/committer.go | 5 +- migration/trie/counter.go | 52 ++++++--- migration/trie/hashworker.go | 38 +++---- migration/trie/ingestor.go | 199 ++++++++++++++--------------------- migration/trie/trie.go | 125 +++++++++++----------- migration/trie/trie_test.go | 39 +++---- 7 files changed, 231 insertions(+), 235 deletions(-) diff --git a/migration/trie/codec_test.go b/migration/trie/codec_test.go index 9a4f53722a..7f0e201861 100644 --- a/migration/trie/codec_test.go +++ b/migration/trie/codec_test.go @@ -136,8 +136,8 @@ func TestCompressedSegment_Length(t *testing.T) { } func TestOldTriePrefix_GlobalTrie(t *testing.T) { - desc := TrieDesc{OldBucket: db.ClassesTrie, Owner: felt.Address{}} - prefix := oldTriePrefix(desc) + desc := TrieDesc{DeprecatedTrieBucket: db.ClassesTrie, Owner: felt.Address{}} + prefix := deprecatedTriePrefix(desc) assert.Equal(t, []byte{byte(db.ClassesTrie)}, prefix) } @@ -145,8 +145,8 @@ func TestOldTriePrefix_StorageTrie(t *testing.T) { var ownerFelt felt.Felt ownerFelt.SetUint64(42) owner := felt.Address(ownerFelt) - desc := TrieDesc{OldBucket: db.ContractStorage, Owner: owner} - prefix := oldTriePrefix(desc) + desc := TrieDesc{DeprecatedTrieBucket: db.ContractStorage, Owner: owner} + prefix := deprecatedTriePrefix(desc) assert.Equal(t, byte(db.ContractStorage), prefix[0]) assert.Equal(t, 1+felt.Bytes, len(prefix)) } diff --git a/migration/trie/committer.go b/migration/trie/committer.go index d3a29ecbd9..4086e3ff7e 100644 --- a/migration/trie/committer.go +++ b/migration/trie/committer.go @@ -16,10 +16,11 @@ type committer struct { func newCommitter( logger log.StructuredLogger, batchSem semaphore.ResourceSemaphore[db.Batch], + allTries, allNodes uint64, ) *committer { return &committer{ batchSem: batchSem, - counter: newCounter(logger, timeLogRate), + counter: newCounter(logger, timeLogRate, allTries, allNodes), } } @@ -28,7 +29,7 @@ func (c *committer) Run(_ int, t task, _ chan<- struct{}) error { if err := t.batch.Write(); err != nil { return fmt.Errorf("trie migration: batch write failed: %w", err) } - c.counter.log(byteSize, t.tries) + c.counter.log(byteSize, t.tries, t.nodes) c.batchSem.Put() return nil } diff --git a/migration/trie/counter.go b/migration/trie/counter.go index 58978cba12..88cdd47d19 100644 --- a/migration/trie/counter.go +++ b/migration/trie/counter.go @@ -1,6 +1,7 @@ package trie import ( + "fmt" "time" "github.com/NethermindEth/juno/db" @@ -9,24 +10,41 @@ import ( ) type counter struct { - logger log.StructuredLogger - timeLogRate time.Duration - start time.Time - size uint64 - tries uint64 + logger log.StructuredLogger + timeLogRate time.Duration + migrationStart time.Time + allTries uint64 + allNodes uint64 + totalTries uint64 + totalNodes uint64 + start time.Time + size uint64 + tries uint64 + nodes uint64 } -func newCounter(logger log.StructuredLogger, timeLogRate time.Duration) counter { +func newCounter( + logger log.StructuredLogger, + timeLogRate time.Duration, + allTries, allNodes uint64, +) counter { + now := time.Now() return counter{ - logger: logger, - timeLogRate: timeLogRate, - start: time.Now(), + logger: logger, + timeLogRate: timeLogRate, + migrationStart: now, + start: now, + allTries: allTries, + allNodes: allNodes, } } -func (c *counter) log(byteSize uint64, tries int) { +func (c *counter) log(byteSize uint64, tries, nodes int) { c.size += byteSize c.tries += uint64(tries) + c.nodes += uint64(nodes) + c.totalTries += uint64(tries) + c.totalNodes += uint64(nodes) now := time.Now() elapsed := now.Sub(c.start).Seconds() @@ -36,12 +54,22 @@ func (c *counter) log(byteSize uint64, tries int) { "write speed", zap.Float64("MB", mbs), zap.Float64("MB/s", mbs/elapsed), - zap.Uint64("tries", c.tries), + zap.Float64("nodes/s", float64(c.nodes)/elapsed), zap.Float64("tries/s", float64(c.tries)/elapsed), - zap.Float64("time", elapsed), + zap.String("tries_processed", fmtPercent(c.totalTries, c.allTries)), + zap.String("nodes_processed", fmtPercent(c.totalNodes, c.allNodes)), + zap.Float64("totalTime", now.Sub(c.migrationStart).Seconds()), ) c.start = now c.size = 0 c.tries = 0 + c.nodes = 0 + } +} + +func fmtPercent(done, total uint64) string { + if total == 0 { + return "100.0%" } + return fmt.Sprintf("%.1f%%", 100.0*float64(done)/float64(total)) } diff --git a/migration/trie/hashworker.go b/migration/trie/hashworker.go index f0ed866b16..44aba742cd 100644 --- a/migration/trie/hashworker.go +++ b/migration/trie/hashworker.go @@ -142,26 +142,26 @@ func (s *hashScheduler) writeBinaryAndEdges( return err } - if job.leftSeg.Len() > 0 { - var leftEdgePath trieutils.Path - leftEdgePath.AppendBit(&job.parentPath, 0) - var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte - kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &leftEdgePath, false) - edgeBlob := encodeEdgeNodeInto(ebuf[kl:], &job.leftChildHash, &job.leftSeg) - if err := batch.Put(ebuf[:kl], ebuf[kl:kl+edgeBlob]); err != nil { - return err - } + if err := s.writeEdge(&job.parentPath, 0, &job.leftChildHash, &job.leftSeg, batch); err != nil { + return err } + return s.writeEdge(&job.parentPath, 1, &job.rightChildHash, &job.rightSeg, batch) +} - if job.rightSeg.Len() > 0 { - var rightEdgePath trieutils.Path - rightEdgePath.AppendBit(&job.parentPath, 1) - var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte - kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &rightEdgePath, false) - edgeBlob := encodeEdgeNodeInto(ebuf[kl:], &job.rightChildHash, &job.rightSeg) - if err := batch.Put(ebuf[:kl], ebuf[kl:kl+edgeBlob]); err != nil { - return err - } +func (s *hashScheduler) writeEdge( + parentPath *trieutils.Path, + bit uint8, + childHash *felt.Felt, + seg *trieutils.Path, + batch db.Batch, +) error { + if seg.Len() == 0 { + return nil } - return nil + var edgePath trieutils.Path + edgePath.AppendBit(parentPath, bit) + var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte + kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &edgePath, false) + blob := encodeEdgeNodeInto(ebuf[kl:], childHash, seg) + return batch.Put(ebuf[:kl], ebuf[kl:kl+blob]) } diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index 328d07cb33..16702ae443 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -11,22 +11,12 @@ import ( "github.com/NethermindEth/juno/migration/semaphore" ) -const ( - dfsStackCap = 251 - maxOldKeySize = 1 + 32 + 1 + 32 -) - -type traverseStackState uint8 - -const ( - readNodeState traverseStackState = iota - leftSubtreeDoneState - rightSubtreeDoneState -) +const maxOldKeySize = 1 + 32 + 1 + 32 type task struct { batch db.Batch tries int + nodes int } type ingestor struct { @@ -35,7 +25,6 @@ type ingestor struct { batchSemaphore semaphore.ResourceSemaphore[db.Batch] pool *hashWorkerPool tasks [IngestorCount]task - dfsStacks [IngestorCount][]dfsFrame } func newIngestor( @@ -52,26 +41,25 @@ func newIngestor( } for i := range IngestorCount { in.tasks[i].batch = batchSemaphore.GetBlocking() - in.dfsStacks[i] = make([]dfsFrame, 0, dfsStackCap) } return in } func (i *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { - done, err := hasDestRoot(i.database, desc.NewBucket, &desc.Owner) + done, err := hasDestRoot(i.database, desc.TrieBucket, &desc.Owner) if err != nil { - return fmt.Errorf("hasDestRoot(%v, %x): %w", desc.NewBucket, desc.Owner, err) - } - if done { - return nil + return fmt.Errorf("hasDestRoot(%v, %x): %w", desc.TrieBucket, desc.Owner, err) } t := &i.tasks[index] + if done { + // Already migrated — credit the counts so progress display reaches 100% on resume. + t.tries++ + t.nodes += desc.NodeCount + return i.flush(t, outputs) + } - i.dfsStacks[index], err = migrateTrie( - i.database, desc, i.pool, t, i.flush, outputs, i.dfsStacks[index], - ) - if err != nil { + if err := migrateTrie(i.database, desc, i.pool, t, i.flush, outputs); err != nil { return err } @@ -97,16 +85,17 @@ func (i *ingestor) flush(t *task, outputs chan<- task) error { select { case <-i.ctx.Done(): return i.ctx.Err() - case outputs <- task{batch: t.batch, tries: t.tries}: + case outputs <- task{batch: t.batch, tries: t.tries, nodes: t.nodes}: } t.tries = 0 + t.nodes = 0 t.batch = i.batchSemaphore.GetBlocking() return nil } // migrateTrie writes the new-format representation of a single deprecated -// trie into t.batch. It walks the deprecated trie in DFS order, emitting -// value/binary/edge nodes via the hashScheduler, and calls flush once per +// trie into t.batch. It walks the deprecated trie via DFS, emits value / +// binary / edge nodes through the hashScheduler, and calls flush at each // step to rotate the batch when it hits target size. func migrateTrie( r db.KeyValueReader, @@ -115,116 +104,77 @@ func migrateTrie( t *task, flush func(*task, chan<- task) error, outputs chan<- task, - stack []dfsFrame, -) ([]dfsFrame, error) { - stack = stack[:0] +) error { if desc.RootPath == nil { - return stack, nil + return nil } parallelDispatch := desc.NodeCount >= SmallTrieThreshold - prefix := oldTriePrefix(desc) - sched := newHashScheduler(desc.HashFn, parallelDispatch, desc.NewBucket, desc.Owner, pool) + prefix := deprecatedTriePrefix(desc) + sched := newHashScheduler(desc.HashFn, parallelDispatch, desc.TrieBucket, desc.Owner, pool) - rootHash, stack, err := traverse(r, prefix, desc.RootPath, sched, t, flush, outputs, stack) + rootHash, err := traverse(r, prefix, *desc.RootPath, sched, t, flush, outputs) if err != nil { - return stack, err + return err } if err := sched.sync(t.batch); err != nil { - return stack, err + return err } if desc.RootPath.Len() > 0 { if err := writeRootEdge(desc.RootPath, rootHash, sched, t.batch); err != nil { - return stack, err + return err } } - return stack, nil -} - -type dfsFrame struct { - oldPath trie.BitArray - left trie.BitArray - right trie.BitArray - value felt.Felt - leftHash felt.Felt - isLeaf bool - state traverseStackState -} - -type parsedNode struct { - value felt.Felt - left trie.BitArray - right trie.BitArray - isLeaf bool + return nil } +// traverse walks the deprecated trie rooted at oldPath in DFS order, writing +// the new-format equivalents into t.batch via sched. Returns the hash of +// the visited subtree — used by the caller to wire up parent binary nodes. func traverse( r db.KeyValueReader, prefix []byte, - start *trie.BitArray, + oldPath trie.BitArray, sched *hashScheduler, t *task, flush func(*task, chan<- task) error, outputs chan<- task, - stack []dfsFrame, -) (felt.Felt, []dfsFrame, error) { - stack = stack[:1] - stack[0] = dfsFrame{oldPath: *start} - var lastHash felt.Felt - for len(stack) > 0 { - top := &stack[len(stack)-1] - switch top.state { - case readNodeState: - parsed, err := readNode(r, prefix, &top.oldPath) - if err != nil { - return felt.Felt{}, stack, err - } - top.value = parsed.value - top.left = parsed.left - top.right = parsed.right - top.isLeaf = parsed.isLeaf - if top.isLeaf { - newPath := toNewPath(&top.oldPath) - if err := processLeaf(newPath, &top.value, sched, t.batch); err != nil { - return felt.Felt{}, stack, err - } - lastHash = top.value - stack = stack[:len(stack)-1] - } else { - top.state = leftSubtreeDoneState - stack = pushFrame(stack, top.left) - } - case leftSubtreeDoneState: - top.leftHash = lastHash - top.state = rightSubtreeDoneState - stack = pushFrame(stack, top.right) - case rightSubtreeDoneState: - newPath := toNewPath(&top.oldPath) - if err := processBinary( - newPath, - &top.left, - &top.right, - top.leftHash, - lastHash, - sched, - t.batch, - ); err != nil { - return felt.Felt{}, stack, err - } - lastHash = top.value - stack = stack[:len(stack)-1] +) (felt.Felt, error) { + parsed, err := readNode(r, prefix, &oldPath) + if err != nil { + return felt.Felt{}, err + } + t.nodes++ + + if parsed.isLeaf { + newPath := toNewPath(&oldPath) + if err := processLeaf(newPath, &parsed.value, sched, t.batch); err != nil { + return felt.Felt{}, err } if err := flush(t, outputs); err != nil { - return felt.Felt{}, stack, err + return felt.Felt{}, err } + return parsed.value, nil } - return lastHash, stack, nil -} -func pushFrame(stack []dfsFrame, oldPath trie.BitArray) []dfsFrame { - n := len(stack) - stack = stack[:n+1] - stack[n] = dfsFrame{oldPath: oldPath} - return stack + leftHash, err := traverse(r, prefix, parsed.left, sched, t, flush, outputs) + if err != nil { + return felt.Felt{}, err + } + rightHash, err := traverse(r, prefix, parsed.right, sched, t, flush, outputs) + if err != nil { + return felt.Felt{}, err + } + + newPath := toNewPath(&oldPath) + if err := processBinary( + newPath, &parsed.left, &parsed.right, leftHash, rightHash, sched, t.batch, + ); err != nil { + return felt.Felt{}, err + } + if err := flush(t, outputs); err != nil { + return felt.Felt{}, err + } + return parsed.value, nil } func encodeOldPath(path *trie.BitArray, dst []byte) int { @@ -236,9 +186,16 @@ func encodeOldPath(path *trie.BitArray, dst []byte) int { return int(activeBytes) + 1 } +type parsedNode struct { + value felt.Felt + left trie.BitArray + right trie.BitArray + isLeaf bool +} + // readNode loads the deprecated-format node at (prefix, oldPath) and returns -// its parsed fields. The caller assigns the parsed values into its own state -// (e.g. a DFS stack frame); this function does not mutate any input. +// its parsed fields. The caller owns the result; this function does not +// mutate any input. func readNode(r db.KeyValueReader, prefix []byte, oldPath *trie.BitArray) (parsedNode, error) { var arr [maxOldKeySize]byte n := copy(arr[:], prefix) @@ -327,11 +284,17 @@ func writeRootEdge( ) } -func oldTriePrefix(desc TrieDesc) []byte { - if desc.OldBucket == db.ContractStorage { - ownerFelt := felt.Felt(desc.Owner) - ownerBytes := ownerFelt.Bytes() - return desc.OldBucket.Key(ownerBytes[:]) +func deprecatedTriePrefix(desc TrieDesc) []byte { + switch desc.DeprecatedTrieBucket { + case db.ClassesTrie, db.StateTrie: + return desc.DeprecatedTrieBucket.Key() + case db.ContractStorage: + ownerBytes := desc.Owner.Bytes() + return desc.DeprecatedTrieBucket.Key(ownerBytes[:]) + default: + panic(fmt.Sprintf( + "unexpected deprecated trie bucket %v", + desc.DeprecatedTrieBucket, + )) } - return desc.OldBucket.Key() } diff --git a/migration/trie/trie.go b/migration/trie/trie.go index 9b00e55ea8..9fcfb5f6e0 100644 --- a/migration/trie/trie.go +++ b/migration/trie/trie.go @@ -97,6 +97,12 @@ func runMigration( return shouldRerun, err } + var allNodes uint64 + for _, d := range tries { + allNodes += uint64(d.NodeCount) + } + allTries := uint64(len(tries)) + src := pipeline.Source(func(yield func(TrieDesc) bool) { for _, d := range tries { if !yield(d) { @@ -108,7 +114,7 @@ func runMigration( committed := pipeline.New( ingested, 1, - newCommitter(logger, batchSem), + newCommitter(logger, batchSem, allTries, allNodes), ) _, wait := committed.Run(ctx) @@ -143,12 +149,12 @@ func wipeDeprecatedBuckets(database db.KeyValueRangeDeleter) error { } type TrieDesc struct { - OldBucket db.Bucket - NewBucket db.Bucket - Owner felt.Address - HashFn crypto.HashFn - NodeCount int - RootPath *trie.BitArray + DeprecatedTrieBucket db.Bucket + TrieBucket db.Bucket + Owner felt.Address + HashFn crypto.HashFn + NodeCount int + RootPath *trie.BitArray } func enumerateTries(r db.KeyValueReader) ([]TrieDesc, error) { @@ -188,88 +194,83 @@ func enumerateGlobalTrie( return TrieDesc{}, fmt.Errorf("opening iterator for bucket %v: %w", oldBucket, err) } defer it.Close() + it.First() - var rootPath *trie.BitArray - count := 0 - for valid := it.First(); valid; valid = it.Next() { - key := it.Key() - if len(key) == len(prefix) { - val, verr := it.Value() - if verr != nil { - return TrieDesc{}, fmt.Errorf("reading root path for bucket %v: %w", oldBucket, verr) - } - rp, perr := parseRootPath(val) - if perr != nil { - return TrieDesc{}, fmt.Errorf("parsing root path for bucket %v: %w", oldBucket, perr) - } - rootPath = rp - } else { - count++ - } + rootPath, count, err := scanTrie(it, prefix) + if err != nil { + return TrieDesc{}, fmt.Errorf("enumerating bucket %v: %w", oldBucket, err) } return TrieDesc{ - OldBucket: oldBucket, - NewBucket: newBucket, - HashFn: hashFn, - NodeCount: count, - RootPath: rootPath, + DeprecatedTrieBucket: oldBucket, + TrieBucket: newBucket, + HashFn: hashFn, + NodeCount: count, + RootPath: rootPath, }, nil } func enumerateStorageTries(r db.KeyValueReader) ([]TrieDesc, error) { - storagePrefix := db.ContractStorage.Key() - it, err := r.NewIterator(storagePrefix, true) + it, err := r.NewIterator(db.ContractStorage.Key(), true) if err != nil { return nil, fmt.Errorf("opening storage iterator: %w", err) } defer it.Close() + it.First() var descs []TrieDesc - for valid := it.First(); valid; valid = it.Valid() { + for it.Valid() { key := it.Key() if len(key) < 1+felt.Bytes { it.Next() continue } - ownerFelt := felt.FromBytes[felt.Felt](key[1 : 1+felt.Bytes]) - owner := felt.Address(ownerFelt) - ownerBytes := ownerFelt.Bytes() + owner := felt.FromBytes[felt.Address](key[1 : 1+felt.Bytes]) + ownerBytes := owner.Bytes() ownerPrefix := db.ContractStorage.Key(ownerBytes[:]) - var rootPath *trie.BitArray - count := 0 - for it.Valid() { - k := it.Key() - if !bytes.HasPrefix(k, ownerPrefix) { - break - } - if len(k) == len(ownerPrefix) { - val, verr := it.Value() - if verr != nil { - return nil, fmt.Errorf("reading root path for storage owner %s: %w", &ownerFelt, verr) - } - rp, perr := parseRootPath(val) - if perr != nil { - return nil, fmt.Errorf("parsing root path for storage owner %s: %w", &ownerFelt, perr) - } - rootPath = rp - } else { - count++ - } - it.Next() + rootPath, count, err := scanTrie(it, ownerPrefix) + if err != nil { + return nil, fmt.Errorf("enumerating storage owner %s: %w", &owner, err) } descs = append(descs, TrieDesc{ - OldBucket: db.ContractStorage, - NewBucket: db.ContractTrieStorage, - Owner: owner, - HashFn: crypto.Pedersen, - NodeCount: count, - RootPath: rootPath, + DeprecatedTrieBucket: db.ContractStorage, + TrieBucket: db.ContractTrieStorage, + Owner: owner, + HashFn: crypto.Pedersen, + NodeCount: count, + RootPath: rootPath, }) + // scanTrie leaves the iterator positioned past this owner's range. } return descs, nil } +func scanTrie(it db.Iterator, prefix []byte) (*trie.BitArray, int, error) { + var rootPath *trie.BitArray + count := 0 + for it.Valid() { + key := it.Key() + if !bytes.HasPrefix(key, prefix) { + return rootPath, count, nil + } + if len(key) == len(prefix) { + val, err := it.Value() + if err != nil { + return nil, 0, err + } + parsedRootPath, err := parseRootPath(val) + if err != nil { + return nil, 0, err + } + rootPath = parsedRootPath + } else { + count++ + } + it.Next() + } + return rootPath, count, nil +} + func parseRootPath(val []byte) (*trie.BitArray, error) { if len(val) == 0 { return nil, nil diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index 840c55aa2d..c432b1f528 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -239,18 +239,17 @@ func TestMigrationIsResumable(t *testing.T) { return perr })) classDesc := TrieDesc{ - OldBucket: db.ClassesTrie, - NewBucket: db.ClassTrie, - HashFn: crypto.Poseidon, - NodeCount: len(leaves), - RootPath: classRootPath, + DeprecatedTrieBucket: db.ClassesTrie, + TrieBucket: db.ClassTrie, + HashFn: crypto.Poseidon, + NodeCount: len(leaves), + RootPath: classRootPath, } pool := newHashWorkerPool() defer pool.close() t1 := &task{batch: partialDB.NewBatch()} - stack := make([]dfsFrame, 0, dfsStackCap) - _, err = migrateTrie(partialDB, classDesc, pool, t1, nopFlush, nil, stack) + err = migrateTrie(partialDB, classDesc, pool, t1, nopFlush, nil) require.NoError(t, err) require.NoError(t, t1.batch.Write()) @@ -269,7 +268,7 @@ func TestMigrationIsResumable(t *testing.T) { descsRemaining := collectTries(t, partialDB) foundClass := false for _, d := range descsRemaining { - if d.OldBucket == db.ClassesTrie { + if d.DeprecatedTrieBucket == db.ClassesTrie { foundClass = true } } @@ -408,18 +407,18 @@ func TestEnumerateTries_EmptyDBYieldsClassAndContractTries(t *testing.T) { memDB := memory.New() descs := collectTries(t, memDB) require.Len(t, descs, 2) - assert.Equal(t, db.ClassesTrie, descs[0].OldBucket) - assert.Equal(t, db.StateTrie, descs[1].OldBucket) + assert.Equal(t, db.ClassesTrie, descs[0].DeprecatedTrieBucket) + assert.Equal(t, db.StateTrie, descs[1].DeprecatedTrieBucket) } func TestEnumerateTries_GlobalTriesPresent(t *testing.T) { memDB := memory.New() descs := collectTries(t, memDB) hasClass := slices.ContainsFunc(descs, func(d TrieDesc) bool { - return d.OldBucket == db.ClassesTrie && d.NewBucket == db.ClassTrie + return d.DeprecatedTrieBucket == db.ClassesTrie && d.TrieBucket == db.ClassTrie }) hasContract := slices.ContainsFunc(descs, func(d TrieDesc) bool { - return d.OldBucket == db.StateTrie && d.NewBucket == db.ContractTrieContract + return d.DeprecatedTrieBucket == db.StateTrie && d.TrieBucket == db.ContractTrieContract }) assert.True(t, hasClass) assert.True(t, hasContract) @@ -440,8 +439,8 @@ func TestEnumerateTries_StorageTriesDiscovered(t *testing.T) { storageCount := 0 for _, d := range descs { - if d.OldBucket == db.ContractStorage { - assert.Equal(t, db.ContractTrieStorage, d.NewBucket) + if d.DeprecatedTrieBucket == db.ContractStorage { + assert.Equal(t, db.ContractTrieStorage, d.TrieBucket) storageCount++ } } @@ -458,7 +457,9 @@ func TestEnumerateTries_NodeCountIsCorrect(t *testing.T) { descs := collectTries(t, memDB) require.Len(t, descs, 3) - idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) + idx := slices.IndexFunc(descs, func(d TrieDesc) bool { + return d.DeprecatedTrieBucket == db.ContractStorage + }) require.NotEqual(t, -1, idx) assert.Equal(t, 7, descs[idx].NodeCount) } @@ -474,7 +475,7 @@ func TestEnumerateTries_StorageTrieCountsPresent(t *testing.T) { descs := collectTries(t, memDB) var storageCounts []int for _, d := range descs { - if d.OldBucket == db.ContractStorage { + if d.DeprecatedTrieBucket == db.ContractStorage { storageCounts = append(storageCounts, d.NodeCount) } } @@ -492,7 +493,9 @@ func TestEnumerateTries_StorageTrieOwnerMatchesKey(t *testing.T) { descs := collectTries(t, memDB) require.Len(t, descs, 3) - idx := slices.IndexFunc(descs, func(d TrieDesc) bool { return d.OldBucket == db.ContractStorage }) + idx := slices.IndexFunc(descs, func(d TrieDesc) bool { + return d.DeprecatedTrieBucket == db.ContractStorage + }) require.NotEqual(t, -1, idx) assert.Equal(t, owner, descs[idx].Owner) } @@ -512,7 +515,7 @@ func TestEnumerateTries_MultipleOwnersOrdered(t *testing.T) { var storageCounts []int for _, d := range descs { - if d.OldBucket == db.ContractStorage { + if d.DeprecatedTrieBucket == db.ContractStorage { storageCounts = append(storageCounts, d.NodeCount) } } From 5bf2e19f82d6f9027e14b203e70d3fd5c1d330a1 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 14:10:45 +0200 Subject: [PATCH 11/14] chore: test only public API --- migration/trie/codec_test.go | 163 -------------- migration/trie/hashworker_test.go | 151 ------------- migration/trie/trie_test.go | 350 +++++++++++------------------- 3 files changed, 131 insertions(+), 533 deletions(-) delete mode 100644 migration/trie/codec_test.go delete mode 100644 migration/trie/hashworker_test.go diff --git a/migration/trie/codec_test.go b/migration/trie/codec_test.go deleted file mode 100644 index 7f0e201861..0000000000 --- a/migration/trie/codec_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package trie - -import ( - "bytes" - "fmt" - "testing" - - "github.com/NethermindEth/juno/core/crypto" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/core/trie2/trienode" - "github.com/NethermindEth/juno/core/trie2/trieutils" - "github.com/NethermindEth/juno/db" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func makeOldPath(length uint8, val uint64) trie.BitArray { - var ba trie.BitArray - ba.SetUint64(length, val) - return ba -} - -func makeNewPath(length uint8, val uint64) trieutils.Path { - old := makeOldPath(length, val) - return toNewPath(&old) -} - -func TestToNewPath_PreservesLengthAndBits(t *testing.T) { - for _, tc := range []struct { - name string - len uint8 - val uint64 - }{ - {"zero length", 0, 0}, - {"single bit 0", 1, 0}, - {"single bit 1", 1, 1}, - {"8 bits", 8, 0xAB}, - {"251 bits", 251, 0xDEADBEEF}, - } { - t.Run(tc.name, func(t *testing.T) { - old := makeOldPath(tc.len, tc.val) - np := toNewPath(&old) - assert.Equal(t, tc.len, np.Len()) - assert.Equal(t, old.Bytes(), np.Bytes()) - }) - } -} - -func TestEncodeValueNode(t *testing.T) { - var v felt.Felt - v.SetUint64(0xCAFEBABE) - blob := encodeValueNode(&v) - assert.Equal(t, valueNodeBlobSize, len(blob)) - assert.Equal(t, v.Bytes(), blob) -} - -func TestEncodeBinaryNode(t *testing.T) { - var l, r felt.Felt - l.SetUint64(1) - r.SetUint64(2) - blob := encodeBinaryNode(&l, &r) - assert.Equal(t, binaryNodeBlobSize, len(blob)) - assert.Equal(t, binaryNodeTag, blob[0]) - lb := l.Bytes() - rb := r.Bytes() - assert.Equal(t, lb[:], blob[1:33]) - assert.Equal(t, rb[:], blob[33:65]) -} - -func TestEncodeEdgeNode(t *testing.T) { - for _, pathLen := range []uint8{0, 1, 10, 250} { - t.Run(fmt.Sprintf("pathLen=%d", pathLen), func(t *testing.T) { - var childHash felt.Felt - childHash.SetUint64(42) - seg := makeNewPath(pathLen, 0b101) - - var buf [edgeNodeMaxSize]byte - n := encodeEdgeNodeInto(buf[:], &childHash, &seg) - blob := buf[:n] - - require.Greater(t, n, 0) - assert.Equal(t, edgeNodeTag, blob[0]) - - var got felt.Felt - got.SetBytes(blob[1:33]) - assert.Equal(t, childHash, got) - - activeBytes := (int(pathLen) + 7) / 8 - assert.Equal(t, 1+felt.Bytes+activeBytes+1, n) - }) - } -} - -func TestComputeEdgeHash_ZeroLengthPath_ReturnChildHashUnchanged(t *testing.T) { - var childHash felt.Felt - childHash.SetUint64(99) - var path trieutils.Path - result := computeEdgeHash(&childHash, &path, crypto.Pedersen) - assert.True(t, result.Equal(&childHash)) -} - -func TestComputeEdgeHash_MatchesEdgeNodeHash(t *testing.T) { - for _, hashFn := range []crypto.HashFn{crypto.Pedersen, crypto.Poseidon} { - for _, pathLen := range []uint8{1, 10, 250} { - var childHash felt.Felt - childHash.SetUint64(123456) - seg := makeNewPath(pathLen, 0b1011) - - got := computeEdgeHash(&childHash, &seg, hashFn) - - hashNode := trienode.HashNode(childHash) - edge := &trienode.EdgeNode{Child: &hashNode, Path: &seg} - want := edge.Hash(hashFn) - - assert.True(t, got.Equal(&want), "pathLen=%d", pathLen) - } - } -} - -func TestCompressedSegment_Length(t *testing.T) { - for _, tc := range []struct { - parentLen uint8 - segLen uint8 - }{ - {0, 0}, - {0, 5}, - {10, 20}, - {100, 50}, - } { - childLen := tc.parentLen + 1 + tc.segLen - child := makeOldPath(childLen, 0b111) - seg := compressedSegment(&child, tc.parentLen) - assert.Equal(t, tc.segLen, seg.Len(), "parentLen=%d segLen=%d", tc.parentLen, tc.segLen) - } -} - -func TestOldTriePrefix_GlobalTrie(t *testing.T) { - desc := TrieDesc{DeprecatedTrieBucket: db.ClassesTrie, Owner: felt.Address{}} - prefix := deprecatedTriePrefix(desc) - assert.Equal(t, []byte{byte(db.ClassesTrie)}, prefix) -} - -func TestOldTriePrefix_StorageTrie(t *testing.T) { - var ownerFelt felt.Felt - ownerFelt.SetUint64(42) - owner := felt.Address(ownerFelt) - desc := TrieDesc{DeprecatedTrieBucket: db.ContractStorage, Owner: owner} - prefix := deprecatedTriePrefix(desc) - assert.Equal(t, byte(db.ContractStorage), prefix[0]) - assert.Equal(t, 1+felt.Bytes, len(prefix)) -} - -func TestParseOldPath_RoundTrip(t *testing.T) { - original := makeOldPath(17, 0b10101) - var buf bytes.Buffer - _, err := original.Write(&buf) - require.NoError(t, err) - var parsed trie.BitArray - require.NoError(t, parsed.UnmarshalBinary(buf.Bytes())) - assert.Equal(t, original.Len(), parsed.Len()) - assert.Equal(t, original.Bytes(), parsed.Bytes()) -} diff --git a/migration/trie/hashworker_test.go b/migration/trie/hashworker_test.go deleted file mode 100644 index 0c89893f76..0000000000 --- a/migration/trie/hashworker_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package trie - -import ( - "testing" - - "github.com/NethermindEth/juno/core/crypto" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie2/trieutils" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func makeSimpleJob(parent *trieutils.Path, leftChild, rightChild *felt.Felt) *edgeHashJob { - return &edgeHashJob{ - leftChildHash: *leftChild, - rightChildHash: *rightChild, - parentPath: *parent, - } -} - -func TestHashScheduler_InlineWritesNodeToBatch(t *testing.T) { - memDB := memory.New() - batch := memDB.NewBatch() - - sched := newHashScheduler(crypto.Pedersen, false, db.ClassTrie, felt.Address{}, nil) - parent := makeNewPath(3, 0b101) - left := felt.NewFromUint64[felt.Felt](1) - right := felt.NewFromUint64[felt.Felt](2) - - require.NoError(t, sched.schedule(makeSimpleJob(&parent, left, right), batch)) - require.NoError(t, batch.Write()) - - blob := encodeBinaryNode(left, right) - val, err := trieutils.GetNodeByPath(memDB, db.ClassTrie, &felt.Address{}, &parent, false) - require.NoError(t, err) - assert.Equal(t, blob[:], val) -} - -func TestHashScheduler_BatchedMatchesInlinePedersen(t *testing.T) { - testBatchedMatchesInline(t, crypto.Pedersen, 200) -} - -func TestHashScheduler_BatchedMatchesInlinePoseidon(t *testing.T) { - testBatchedMatchesInline(t, crypto.Poseidon, 200) -} - -func testBatchedMatchesInline(t *testing.T, hashFn crypto.HashFn, n int) { - t.Helper() - - pool := newHashWorkerPool() - t.Cleanup(pool.close) - - jobs := make([]edgeHashJob, n) - for i := range n { - l := felt.NewFromUint64[felt.Felt](uint64(i*2 + 1)) - r := felt.NewFromUint64[felt.Felt](uint64(i*2 + 2)) - path := makeNewPath(uint8(i%250), uint64(i)) - jobs[i] = *makeSimpleJob(&path, l, r) - } - - inlineDB := memory.New() - { - batch := inlineDB.NewBatch() - sched := newHashScheduler(hashFn, false, db.ClassTrie, felt.Address{}, nil) - for i := range jobs { - require.NoError(t, sched.schedule(&jobs[i], batch)) - } - require.NoError(t, batch.Write()) - } - - batchDB := memory.New() - { - batch := batchDB.NewBatch() - sched := newHashScheduler(hashFn, true, db.ClassTrie, felt.Address{}, pool) - for i := range jobs { - require.NoError(t, sched.schedule(&jobs[i], batch)) - } - require.NoError(t, sched.sync(batch)) - require.NoError(t, batch.Write()) - } - - for i := range jobs { - job := &jobs[i] - inlineVal, err := trieutils.GetNodeByPath( - inlineDB, - db.ClassTrie, - &felt.Address{}, - &job.parentPath, - false, - ) - require.NoError(t, err, "path %d missing in inline", i) - batchVal, err := trieutils.GetNodeByPath( - batchDB, - db.ClassTrie, - &felt.Address{}, - &job.parentPath, - false, - ) - require.NoError(t, err, "path %d missing in batched", i) - assert.Equal(t, inlineVal, batchVal, "mismatch at path %d", i) - } -} - -func TestHashScheduler_AutoFlushAtBatchSize(t *testing.T) { - pool := newHashWorkerPool() - t.Cleanup(pool.close) - - memDB := memory.New() - batch := memDB.NewBatch() - - sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, felt.Address{}, pool) - for i := range parallelHashBatchSize { - path := makeNewPath(251, uint64(i+1)) - leftChildHash := felt.NewFromUint64[felt.Felt](uint64(i + 1)) - rightChildHash := felt.NewFromUint64[felt.Felt](uint64(i + 2)) - job := makeSimpleJob(&path, leftChildHash, rightChildHash) - require.NoError(t, sched.schedule(job, batch)) - } - // After filling a full batch, jobs are handed off to pool; local slice is reset - assert.Empty(t, sched.jobs, "jobs should be empty after auto-flush") - // Drain in-flight batch before pool.close() to avoid send-on-closed-channel - require.NoError(t, sched.sync(batch)) -} - -func TestHashScheduler_SingleJobDispatchesCorrectly(t *testing.T) { - pool := newHashWorkerPool() - t.Cleanup(pool.close) - - memDB := memory.New() - batch := memDB.NewBatch() - - sched := newHashScheduler(crypto.Pedersen, true, db.ClassTrie, felt.Address{}, pool) - parent := makeNewPath(4, 0b1010) - l, r := felt.NewFromUint64[felt.Felt](7), felt.NewFromUint64[felt.Felt](13) - job := makeSimpleJob(&parent, l, r) - - require.NoError(t, sched.schedule(job, batch)) - require.NoError(t, sched.sync(batch)) - require.NoError(t, batch.Write()) - - blob := encodeBinaryNode(l, r) - val, err := trieutils.GetNodeByPath(memDB, db.ClassTrie, &felt.Address{}, &parent, false) - require.NoError(t, err) - assert.Equal(t, blob[:], val) -} - -func TestHashScheduler_LargeBatchDispatchesCorrectly(t *testing.T) { - testBatchedMatchesInline(t, crypto.Pedersen, parallelHashBatchSize*3) -} diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index c432b1f528..00758af76f 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -1,10 +1,8 @@ -package trie +package trie_test import ( - "bytes" "context" "math/rand" - "slices" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -16,6 +14,7 @@ import ( "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" + trielib "github.com/NethermindEth/juno/migration/trie" "github.com/NethermindEth/juno/utils/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,9 +22,6 @@ import ( type leafMap map[felt.Felt]felt.Felt -func nopLogger() log.StructuredLogger { return log.NewNopZapLogger() } -func nopFlush(*task, chan<- task) error { return nil } - type trieCase struct { name string oldBucket db.Bucket @@ -101,40 +97,24 @@ var transcoderCases = []struct { name string leaves leafMap }{ - // --- trivial structural cases --- {"empty trie", nil}, {"single leaf", leafMap{ felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](100), }}, - - // --- two-leaf cases probing specific split depths --- - - // Keys 2 and 3 differ only in the last path bit (bit 250 of felt = path bit 250). - // Root edge spans 250 bits; binary node is at maximum depth. {"deep split", leafMap{ felt.FromUint64[felt.Felt](2): felt.FromUint64[felt.Felt](10), felt.FromUint64[felt.Felt](3): felt.FromUint64[felt.Felt](20), }}, - - // One key < 2^250 (trie path bit[0]=0), one key ≥ 2^250 (trie path bit[0]=1). - // Root is a binary node — rootPath.Len()==0, no root edge written. - // This path is never hit by sequential small-integer keys. {"left right split", leafMap{ felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](20), // 2^250 }}, - - // --- four leaves covering all 2-bit path prefixes (00, 01, 10, 11) --- - // Root is a binary node; left and right subtrees each contain a binary node. - // Tests two levels of binary processing with rootPath.Len()==0. {"full depth 2 tree", leafMap{ felt.FromUint64[felt.Felt](1): felt.FromUint64[felt.Felt](10), // 00... - felt.FromBytes[felt.Felt]([]byte{0x02}): felt.FromUint64[felt.Felt](20), // 01... (2^249) - felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](30), // 10... (2^250) - felt.FromBytes[felt.Felt]([]byte{0x06}): felt.FromUint64[felt.Felt](40), // 11... (2^250+2^249) + felt.FromBytes[felt.Felt]([]byte{0x02}): felt.FromUint64[felt.Felt](20), // 01... + felt.FromBytes[felt.Felt]([]byte{0x04}): felt.FromUint64[felt.Felt](30), // 10... + felt.FromBytes[felt.Felt]([]byte{0x06}): felt.FromUint64[felt.Felt](40), // 11... }}, - - // --- larger sequential case for basic load --- {"hundred sequential leaves", func() leafMap { leaves := make(leafMap, 100) for i := 1; i <= 100; i++ { @@ -142,17 +122,13 @@ var transcoderCases = []struct { } return leaves }()}, - - // --- random leaves spanning the full 251-bit space --- - // Keys are evenly distributed across all trie depths, exercising every code - // path: balanced binary nodes, varied edge lengths, and batch flush thresholds. {"random 1000 leaves", randomLeaves(1000, 42)}, } func TestMigrate_FreshDBIsNoOp(t *testing.T) { memDB := memory.New() - state, err := (&Migrator{}).Migrate(context.Background(), memDB, nil, nopLogger()) + state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) require.NoError(t, err) assert.Nil( t, @@ -165,17 +141,17 @@ func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { leaves := randomLeaves(100, 7) memDB := buildFullDB(t, leaves) - needed, err := needsMigration(memDB) - require.NoError(t, err) - require.True(t, needed, "precondition: DB has old-format data") + require.True(t, bucketHasKeys(t, memDB, db.ClassesTrie), "precondition: DB has old-format data") - state, err := (&Migrator{}).Migrate(context.Background(), memDB, nil, nopLogger()) + state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) require.NoError(t, err) assert.Nil(t, state, "completed migration must return nil intermediate state") - stillNeeded, err := needsMigration(memDB) - require.NoError(t, err) - assert.False(t, stillNeeded, "old-format buckets should be empty after migration") + for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { + assert.False(t, + bucketHasKeys(t, memDB, bucket), + "old-format bucket %v should be empty after migration", bucket) + } } // TestMigrationEndToEnd verifies that the migration produces byte-for-byte @@ -205,7 +181,9 @@ func TestMigrationEndToEnd(t *testing.T) { migratedDB := memory.New() buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) - _, err := runMigration(context.Background(), migratedDB, nopLogger()) + _, err := (&trielib.Migrator{}).Migrate( + context.Background(), migratedDB, nil, log.NewNopZapLogger(), + ) require.NoError(t, err) nativeDB := memory.New() @@ -219,67 +197,33 @@ func TestMigrationEndToEnd(t *testing.T) { } } +// TestMigrationIsResumable verifies that re-running migration over a DB +// whose class-trie destination root is already present skips the class +// migration and only finishes the contract trie. The "partial state" is +// faked by copying the reference DB's new-format class-trie keys into the +// partial DB before running migration. func TestMigrationIsResumable(t *testing.T) { leaves := randomLeaves(1000, 42) // Reference: full migration from scratch. refDB := buildFullDB(t, leaves) - _, err := runMigration(context.Background(), refDB, nopLogger()) + _, err := (&trielib.Migrator{}).Migrate(context.Background(), refDB, nil, log.NewNopZapLogger()) require.NoError(t, err) // Partial DB: both tries in old format initially. partialDB := buildFullDB(t, leaves) - // Manually migrate only the class trie to simulate a mid-run interruption. - classPrefix := db.ClassesTrie.Key() - var classRootPath *trie.BitArray - require.NoError(t, partialDB.Get(classPrefix, func(val []byte) error { - var perr error - classRootPath, perr = parseRootPath(val) - return perr - })) - classDesc := TrieDesc{ - DeprecatedTrieBucket: db.ClassesTrie, - TrieBucket: db.ClassTrie, - HashFn: crypto.Poseidon, - NodeCount: len(leaves), - RootPath: classRootPath, - } - - pool := newHashWorkerPool() - defer pool.close() - t1 := &task{batch: partialDB.NewBatch()} - err = migrateTrie(partialDB, classDesc, pool, t1, nopFlush, nil) - require.NoError(t, err) - require.NoError(t, t1.batch.Write()) - - done, err := hasDestRoot(partialDB, db.ClassTrie, &classDesc.Owner) - require.NoError(t, err) - assert.True(t, done, "class trie destination root should be present after partial migration") - - needed, err := needsMigration(partialDB) - require.NoError(t, err) - assert.True(t, needed, "migration should still be needed after partial work") - - // enumerateTries yields every discovered trie unfiltered; the - // already-done short-circuit lives in the ingestor now. We assert that - // the class trie *is* yielded here, and rely on the end-to-end equality - // check below to confirm the ingestor's idempotency guard skips it. - descsRemaining := collectTries(t, partialDB) - foundClass := false - for _, d := range descsRemaining { - if d.DeprecatedTrieBucket == db.ClassesTrie { - foundClass = true - } + // Fake a prior successful class-trie migration by copying refDB's + // new-format class-trie keys directly into partialDB. + refClassKeys := allKeysUnder(t, refDB, db.ClassTrie) + require.NotEmpty(t, refClassKeys, "reference class trie should be non-empty") + for k, v := range refClassKeys { + require.NoError(t, partialDB.Put([]byte(k), v)) } - assert.True( - t, - foundClass, - "enumerateTries should yield class trie even when destination root is present", - ) - // Resume: migration completes the remaining contract trie. - _, err = runMigration(context.Background(), partialDB, nopLogger()) + // Resume: migration should skip the class trie (its dest root is present) + // and complete only the contract trie. + _, err = (&trielib.Migrator{}).Migrate(context.Background(), partialDB, nil, log.NewNopZapLogger()) require.NoError(t, err) // Final state must match the reference full-run output for every new-format bucket. @@ -367,158 +311,126 @@ func buildTrie( require.NoError(t, batch.Write()) } -func allKeysUnder(t *testing.T, r db.KeyValueReader, bucket db.Bucket) map[string][]byte { - t.Helper() - prefix := bucket.Key() - iter, err := r.NewIterator(prefix, true) - require.NoError(t, err) - defer iter.Close() - out := make(map[string][]byte) - for ok := iter.First(); ok; ok = iter.Next() { - val, err := iter.Value() - require.NoError(t, err) - out[string(iter.Key())] = val +// TestMigrationMultiStorageOwners exercises enumerateStorageTries across +// multiple owners (scanTrie's prefix-leave path) and keeps all 4 ingestor +// workers busy by giving them 7 tries to chew through (2 global + 5 storage). +func TestMigrationMultiStorageOwners(t *testing.T) { + leaves := randomLeaves(50, 7) + + migratedDB := memory.New() + buildDeprecatedTrie(t, migratedDB, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) + buildDeprecatedTrie(t, migratedDB, leaves, trie.NewTriePedersen, db.StateTrie.Key()) + + owners := []felt.Address{ + felt.FromUint64[felt.Address](1), + felt.FromUint64[felt.Address](2), + felt.FromUint64[felt.Address](3), + felt.FromUint64[felt.Address](42), + felt.FromUint64[felt.Address](999), + } + for _, owner := range owners { + ownerFelt := felt.Felt(owner) + ownerBytes := ownerFelt.Bytes() + buildDeprecatedTrie(t, migratedDB, leaves, trie.NewTriePedersen, + db.ContractStorage.Key(ownerBytes[:])) } - return out -} -func insertFakeStorageNodes(t *testing.T, database db.KeyValueStore, owner felt.Address, n int) { - t.Helper() - ownerFelt := felt.Felt(owner) - ownerBytes := ownerFelt.Bytes() - ownerPrefix := db.ContractStorage.Key(ownerBytes[:]) + _, err := (&trielib.Migrator{}).Migrate(context.Background(), migratedDB, nil, log.NewNopZapLogger()) + require.NoError(t, err) - require.NoError(t, database.Put(ownerPrefix, []byte{0})) + // Per-owner native build → assert every native key is present (with the + // same value) under the merged migrated view. + migratedAll := allKeysUnder(t, migratedDB, db.ContractTrieStorage) + for _, owner := range owners { + nativeDB := memory.New() + id := trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), owner) + buildTrie(t, nativeDB, leaves, id, crypto.Pedersen, db.ContractTrieStorage) + for k, v := range allKeysUnder(t, nativeDB, db.ContractTrieStorage) { + gotV, ok := migratedAll[k] + require.True(t, ok, "owner %v missing key", owner) + assert.Equal(t, v, gotV, "owner %v value differs at key", owner) + } + } - for i := range n { - key := append(bytes.Clone(ownerPrefix), 8, byte(i)) - require.NoError(t, database.Put(key, []byte{0xFF})) + // Old buckets fully drained. + for _, bucket := range []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} { + assert.False(t, bucketHasKeys(t, migratedDB, bucket), + "old bucket %v should be drained", bucket) } } -func collectTries(t *testing.T, r db.KeyValueReader) []TrieDesc { - t.Helper() - descs, err := enumerateTries(r) - require.NoError(t, err) - return descs -} +// TestMigrationIdempotent verifies that a successful migration is a no-op on +// a second run: needsMigration sees the wiped deprecated buckets and returns +// early without touching the migrated state. +func TestMigrationIsNoopOnSecondRun(t *testing.T) { + leaves := randomLeaves(100, 11) + memDB := buildFullDB(t, leaves) -func TestEnumerateTries_EmptyDBYieldsClassAndContractTries(t *testing.T) { - memDB := memory.New() - descs := collectTries(t, memDB) - require.Len(t, descs, 2) - assert.Equal(t, db.ClassesTrie, descs[0].DeprecatedTrieBucket) - assert.Equal(t, db.StateTrie, descs[1].DeprecatedTrieBucket) -} + state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + require.NoError(t, err) + require.Nil(t, state) + snapshot := snapshotAllBuckets(t, memDB, + db.ClassTrie, db.ContractTrieContract, db.ContractTrieStorage) -func TestEnumerateTries_GlobalTriesPresent(t *testing.T) { - memDB := memory.New() - descs := collectTries(t, memDB) - hasClass := slices.ContainsFunc(descs, func(d TrieDesc) bool { - return d.DeprecatedTrieBucket == db.ClassesTrie && d.TrieBucket == db.ClassTrie - }) - hasContract := slices.ContainsFunc(descs, func(d TrieDesc) bool { - return d.DeprecatedTrieBucket == db.StateTrie && d.TrieBucket == db.ContractTrieContract - }) - assert.True(t, hasClass) - assert.True(t, hasContract) + state, err = (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + require.NoError(t, err) + require.Nil(t, state) + require.Equal(t, snapshot, + snapshotAllBuckets(t, memDB, + db.ClassTrie, db.ContractTrieContract, db.ContractTrieStorage), + "second Migrate call must not change state") } -func TestEnumerateTries_StorageTriesDiscovered(t *testing.T) { - memDB := memory.New() - var owners [3]felt.Address - for i := range owners { - var f felt.Felt - f.SetUint64(uint64(i + 1)) - owners[i] = felt.Address(f) - insertFakeStorageNodes(t, memDB, owners[i], 5) - } +// TestMigrationCancelledContext verifies that a pre-cancelled ctx surfaces +// context.Canceled with the shouldRerun sentinel, and that a fresh ctx +// completes the migration normally afterwards. +func TestMigrationCancelledContext(t *testing.T) { + leaves := randomLeaves(100, 13) + memDB := buildFullDB(t, leaves) - descs := collectTries(t, memDB) - require.Len(t, descs, 5) + ctx, cancel := context.WithCancel(context.Background()) + cancel() - storageCount := 0 - for _, d := range descs { - if d.DeprecatedTrieBucket == db.ContractStorage { - assert.Equal(t, db.ContractTrieStorage, d.TrieBucket) - storageCount++ - } - } - assert.Equal(t, 3, storageCount) -} + state, err := (&trielib.Migrator{}).Migrate(ctx, memDB, nil, log.NewNopZapLogger()) + require.ErrorIs(t, err, context.Canceled) + require.NotNil(t, state, "shouldRerun sentinel must not be nil") + require.Empty(t, state, "shouldRerun is a non-nil empty slice") -func TestEnumerateTries_NodeCountIsCorrect(t *testing.T) { - memDB := memory.New() - var ownerFelt felt.Felt - ownerFelt.SetUint64(99) - owner := felt.Address(ownerFelt) - insertFakeStorageNodes(t, memDB, owner, 7) - - descs := collectTries(t, memDB) - require.Len(t, descs, 3) - - idx := slices.IndexFunc(descs, func(d TrieDesc) bool { - return d.DeprecatedTrieBucket == db.ContractStorage - }) - require.NotEqual(t, -1, idx) - assert.Equal(t, 7, descs[idx].NodeCount) + state, err = (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + require.NoError(t, err) + require.Nil(t, state) } -func TestEnumerateTries_StorageTrieCountsPresent(t *testing.T) { - memDB := memory.New() - for i, n := range []int{3, 7, 1} { - var f felt.Felt - f.SetUint64(uint64(i + 1)) - insertFakeStorageNodes(t, memDB, felt.Address(f), n) - } - - descs := collectTries(t, memDB) - var storageCounts []int - for _, d := range descs { - if d.DeprecatedTrieBucket == db.ContractStorage { - storageCounts = append(storageCounts, d.NodeCount) +func snapshotAllBuckets(t *testing.T, r db.KeyValueReader, buckets ...db.Bucket) map[string][]byte { + t.Helper() + out := make(map[string][]byte) + for _, b := range buckets { + for k, v := range allKeysUnder(t, r, b) { + out[k] = v } } - slices.Sort(storageCounts) - assert.Equal(t, []int{1, 3, 7}, storageCounts) -} - -func TestEnumerateTries_StorageTrieOwnerMatchesKey(t *testing.T) { - memDB := memory.New() - var ownerFelt felt.Felt - ownerFelt.SetUint64(12345) - owner := felt.Address(ownerFelt) - insertFakeStorageNodes(t, memDB, owner, 3) - - descs := collectTries(t, memDB) - require.Len(t, descs, 3) - - idx := slices.IndexFunc(descs, func(d TrieDesc) bool { - return d.DeprecatedTrieBucket == db.ContractStorage - }) - require.NotEqual(t, -1, idx) - assert.Equal(t, owner, descs[idx].Owner) + return out } -func TestEnumerateTries_MultipleOwnersOrdered(t *testing.T) { - memDB := memory.New() - owners := make([]felt.Address, 5) - for i := range owners { - var f felt.Felt - f.SetUint64(uint64(i + 1)) - owners[i] = felt.Address(f) - insertFakeStorageNodes(t, memDB, owners[i], i+1) +func allKeysUnder(t *testing.T, r db.KeyValueReader, bucket db.Bucket) map[string][]byte { + t.Helper() + prefix := bucket.Key() + iter, err := r.NewIterator(prefix, true) + require.NoError(t, err) + defer iter.Close() + out := make(map[string][]byte) + for ok := iter.First(); ok; ok = iter.Next() { + val, err := iter.Value() + require.NoError(t, err) + out[string(iter.Key())] = val } + return out +} - descs := collectTries(t, memDB) - require.Len(t, descs, 7) - - var storageCounts []int - for _, d := range descs { - if d.DeprecatedTrieBucket == db.ContractStorage { - storageCounts = append(storageCounts, d.NodeCount) - } - } - slices.Sort(storageCounts) - assert.Equal(t, []int{1, 2, 3, 4, 5}, storageCounts) +func bucketHasKeys(t *testing.T, r db.KeyValueReader, bucket db.Bucket) bool { + t.Helper() + it, err := r.NewIterator(bucket.Key(), true) + require.NoError(t, err) + defer it.Close() + return it.First() } From 4aac6ceb788ff0caa72be6f719de2c293220fa2a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 14:42:30 +0200 Subject: [PATCH 12/14] chore: self review --- core/trie2/trieutils/accessors.go | 31 ----------------- migration/trie/codec.go | 51 ++++++++++++++++++++++++--- migration/trie/hashpool.go | 8 +++-- migration/trie/hashworker.go | 8 ++--- migration/trie/ingestor.go | 57 ++++++++++++++++--------------- migration/trie/trie.go | 33 ++++++------------ migration/trie/trie_test.go | 28 ++++++++++++--- 7 files changed, 120 insertions(+), 96 deletions(-) diff --git a/core/trie2/trieutils/accessors.go b/core/trie2/trieutils/accessors.go index bbfccc91ce..08f96c7166 100644 --- a/core/trie2/trieutils/accessors.go +++ b/core/trie2/trieutils/accessors.go @@ -39,37 +39,6 @@ func WriteNodeByPath( return w.Put(nodeKeyByPath(bucket, owner, path, isLeaf), blob) } -// MaxNodeKeySize is the maximum byte length of a new-format trie node key: -// 1 (prefix) + 32 (owner, optional) + 1 (nodeType) + MaxBitArraySize (path). -const MaxNodeKeySize = 1 + 32 + 1 + MaxBitArraySize - -// EncodeNodeKey writes the node key into dst and returns the number of bytes written. -// dst must have at least MaxNodeKeySize bytes of capacity. -func EncodeNodeKey(dst []byte, bucket db.Bucket, owner *felt.Address, path *Path, isLeaf bool) int { - n := 0 - dst[n] = byte(bucket) - n++ - - if !felt.IsZero(owner) { - ownerBytes := owner.Bytes() - copy(dst[n:], ownerBytes[:]) - n += 32 - } - - if isLeaf { - dst[n] = leaf.Byte() - } else { - dst[n] = nonLeaf.Byte() - } - n++ - - pathBytes := path.EncodedBytes() - copy(dst[n:], pathBytes) - n += len(pathBytes) - - return n -} - func DeleteNodeByPath( w db.KeyValueWriter, bucket db.Bucket, diff --git a/migration/trie/codec.go b/migration/trie/codec.go index 77e9264113..d9534d9fb4 100644 --- a/migration/trie/codec.go +++ b/migration/trie/codec.go @@ -5,6 +5,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" ) const ( @@ -15,6 +16,18 @@ const ( binaryNodeBlobSize = 1 + 2*felt.Bytes edgeNodeMinSize = 1 + felt.Bytes + 1 edgeNodeMaxSize = 1 + felt.Bytes + trieutils.MaxBitArraySize + + // nonLeafByte and leafByte mirror the on-disk byte values of trieutils' + // unexported leafType enum (core/trie2/trieutils/types.go). They're part + // of the persisted trie format so they cannot change; replicated locally + // rather than exposed publicly from trieutils. TestMigrationEndToEnd's + // byte-for-byte equality with a natively-built trie2 guards against drift. + nonLeafByte byte = 1 + leafByte byte = 2 + + // maxNodeKeySize is the maximum byte length of a new-format trie node key: + // 1 (prefix) + 32 (owner, optional) + 1 (nodeType) + MaxBitArraySize (path). + maxNodeKeySize = 1 + 32 + 1 + trieutils.MaxBitArraySize ) func toNewPath(old *trie.BitArray) trieutils.Path { @@ -24,10 +37,6 @@ func toNewPath(old *trie.BitArray) trieutils.Path { return p } -func encodeValueNode(value *felt.Felt) [valueNodeBlobSize]byte { - return value.Bytes() -} - func encodeBinaryNode(leftEdgeHash, rightEdgeHash *felt.Felt) [binaryNodeBlobSize]byte { var blob [binaryNodeBlobSize]byte blob[0] = binaryNodeTag @@ -63,3 +72,37 @@ func compressedSegment(childFullPath *trie.BitArray, parentLen uint8) trieutils. seg.LSBs(childFullPath, parentLen+1) return toNewPath(&seg) } + +// encodeNodeKey writes the node key into dst and returns the number of bytes +// written. dst must have at least maxNodeKeySize bytes of capacity. Mirrors +// trieutils.nodeKeyByPath without the per-call allocation. +func encodeNodeKey( + dst []byte, + bucket db.Bucket, + owner *felt.Address, + path *trieutils.Path, + isLeaf bool, +) int { + n := 0 + dst[n] = byte(bucket) + n++ + + if !felt.IsZero(owner) { + ownerBytes := owner.Bytes() + copy(dst[n:], ownerBytes[:]) + n += 32 + } + + if isLeaf { + dst[n] = leafByte + } else { + dst[n] = nonLeafByte + } + n++ + + pathBytes := path.EncodedBytes() + copy(dst[n:], pathBytes) + n += len(pathBytes) + + return n +} diff --git a/migration/trie/hashpool.go b/migration/trie/hashpool.go index 89710a33ee..ca08b338c1 100644 --- a/migration/trie/hashpool.go +++ b/migration/trie/hashpool.go @@ -20,9 +20,11 @@ type hashWorkerPool struct { } func newHashWorkerPool() *hashWorkerPool { - n := 4 - p := &hashWorkerPool{work: make(chan hashWork, n*2), n: n} - for range n { + p := &hashWorkerPool{ + work: make(chan hashWork, IngestorCount*2), + n: IngestorCount, + } + for range IngestorCount { go func() { for w := range p.work { for i := range w.jobs { diff --git a/migration/trie/hashworker.go b/migration/trie/hashworker.go index 44aba742cd..9c385e222e 100644 --- a/migration/trie/hashworker.go +++ b/migration/trie/hashworker.go @@ -134,8 +134,8 @@ func (s *hashScheduler) writeBinaryAndEdges( rightEdge *felt.Felt, batch db.Batch, ) error { - var buf [trieutils.MaxNodeKeySize + binaryNodeBlobSize]byte - keyLen := trieutils.EncodeNodeKey(buf[:], s.bucket, &s.owner, &job.parentPath, false) + var buf [maxNodeKeySize + binaryNodeBlobSize]byte + keyLen := encodeNodeKey(buf[:], s.bucket, &s.owner, &job.parentPath, false) blob := encodeBinaryNode(leftEdge, rightEdge) copy(buf[keyLen:], blob[:]) if err := batch.Put(buf[:keyLen], buf[keyLen:keyLen+binaryNodeBlobSize]); err != nil { @@ -160,8 +160,8 @@ func (s *hashScheduler) writeEdge( } var edgePath trieutils.Path edgePath.AppendBit(parentPath, bit) - var ebuf [trieutils.MaxNodeKeySize + edgeNodeMaxSize]byte - kl := trieutils.EncodeNodeKey(ebuf[:], s.bucket, &s.owner, &edgePath, false) + var ebuf [maxNodeKeySize + edgeNodeMaxSize]byte + kl := encodeNodeKey(ebuf[:], s.bucket, &s.owner, &edgePath, false) blob := encodeEdgeNodeInto(ebuf[kl:], childHash, seg) return batch.Put(ebuf[:kl], ebuf[kl:kl+blob]) } diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index 16702ae443..83fdf70a42 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -46,9 +46,9 @@ func newIngestor( } func (i *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { - done, err := hasDestRoot(i.database, desc.TrieBucket, &desc.Owner) + done, err := rootProcessed(i.database, desc.TrieBucket, &desc.Owner) if err != nil { - return fmt.Errorf("hasDestRoot(%v, %x): %w", desc.TrieBucket, desc.Owner, err) + return fmt.Errorf("rootProcessed(%v, %x): %w", desc.TrieBucket, desc.Owner, err) } t := &i.tasks[index] @@ -59,7 +59,7 @@ func (i *ingestor) Run(index int, desc TrieDesc, outputs chan<- task) error { return i.flush(t, outputs) } - if err := migrateTrie(i.database, desc, i.pool, t, i.flush, outputs); err != nil { + if err := i.migrateTrie(t, desc, outputs); err != nil { return err } @@ -97,22 +97,15 @@ func (i *ingestor) flush(t *task, outputs chan<- task) error { // trie into t.batch. It walks the deprecated trie via DFS, emits value / // binary / edge nodes through the hashScheduler, and calls flush at each // step to rotate the batch when it hits target size. -func migrateTrie( - r db.KeyValueReader, - desc TrieDesc, - pool *hashWorkerPool, - t *task, - flush func(*task, chan<- task) error, - outputs chan<- task, -) error { - if desc.RootPath == nil { +func (i *ingestor) migrateTrie(t *task, desc TrieDesc, outputs chan<- task) error { + if desc.NodeCount == 0 { return nil } parallelDispatch := desc.NodeCount >= SmallTrieThreshold prefix := deprecatedTriePrefix(desc) - sched := newHashScheduler(desc.HashFn, parallelDispatch, desc.TrieBucket, desc.Owner, pool) + sched := newHashScheduler(desc.HashFn, parallelDispatch, desc.TrieBucket, desc.Owner, i.pool) - rootHash, err := traverse(r, prefix, *desc.RootPath, sched, t, flush, outputs) + rootHash, err := i.traverse(t, outputs, prefix, *desc.RootPath, sched) if err != nil { return err } @@ -130,16 +123,14 @@ func migrateTrie( // traverse walks the deprecated trie rooted at oldPath in DFS order, writing // the new-format equivalents into t.batch via sched. Returns the hash of // the visited subtree — used by the caller to wire up parent binary nodes. -func traverse( - r db.KeyValueReader, +func (i *ingestor) traverse( + t *task, + outputs chan<- task, prefix []byte, oldPath trie.BitArray, sched *hashScheduler, - t *task, - flush func(*task, chan<- task) error, - outputs chan<- task, ) (felt.Felt, error) { - parsed, err := readNode(r, prefix, &oldPath) + parsed, err := readNode(i.database, prefix, &oldPath) if err != nil { return felt.Felt{}, err } @@ -150,17 +141,17 @@ func traverse( if err := processLeaf(newPath, &parsed.value, sched, t.batch); err != nil { return felt.Felt{}, err } - if err := flush(t, outputs); err != nil { + if err := i.flush(t, outputs); err != nil { return felt.Felt{}, err } return parsed.value, nil } - leftHash, err := traverse(r, prefix, parsed.left, sched, t, flush, outputs) + leftHash, err := i.traverse(t, outputs, prefix, parsed.left, sched) if err != nil { return felt.Felt{}, err } - rightHash, err := traverse(r, prefix, parsed.right, sched, t, flush, outputs) + rightHash, err := i.traverse(t, outputs, prefix, parsed.right, sched) if err != nil { return felt.Felt{}, err } @@ -171,12 +162,24 @@ func traverse( ); err != nil { return felt.Felt{}, err } - if err := flush(t, outputs); err != nil { + if err := i.flush(t, outputs); err != nil { return felt.Felt{}, err } return parsed.value, nil } +func rootProcessed(r db.KeyValueReader, newBucket db.Bucket, owner *felt.Address) (bool, error) { + var emptyPath trieutils.Path + var buf [maxNodeKeySize]byte + + n := encodeNodeKey(buf[:], newBucket, owner, &emptyPath, false) + if exists, err := r.Has(buf[:n]); err != nil || exists { + return exists, err + } + n = encodeNodeKey(buf[:], newBucket, owner, &emptyPath, true) + return r.Has(buf[:n]) +} + func encodeOldPath(path *trie.BitArray, dst []byte) int { pathLen := path.Len() b := path.Bytes() @@ -240,9 +243,9 @@ func processLeaf( sched *hashScheduler, batch db.Batch, ) error { - var buf [trieutils.MaxNodeKeySize + valueNodeBlobSize]byte - keyLen := trieutils.EncodeNodeKey(buf[:], sched.bucket, &sched.owner, &path, true) - blob := encodeValueNode(value) + var buf [maxNodeKeySize + valueNodeBlobSize]byte + keyLen := encodeNodeKey(buf[:], sched.bucket, &sched.owner, &path, true) + blob := value.Bytes() copy(buf[keyLen:], blob[:]) return batch.Put(buf[:keyLen], buf[keyLen:keyLen+valueNodeBlobSize]) } diff --git a/migration/trie/trie.go b/migration/trie/trie.go index 9fcfb5f6e0..9356b10782 100644 --- a/migration/trie/trie.go +++ b/migration/trie/trie.go @@ -11,7 +11,6 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/dbutils" "github.com/NethermindEth/juno/migration" @@ -205,7 +204,7 @@ func enumerateGlobalTrie( TrieBucket: newBucket, HashFn: hashFn, NodeCount: count, - RootPath: rootPath, + RootPath: &rootPath, }, nil } @@ -238,15 +237,15 @@ func enumerateStorageTries(r db.KeyValueReader) ([]TrieDesc, error) { Owner: owner, HashFn: crypto.Pedersen, NodeCount: count, - RootPath: rootPath, + RootPath: &rootPath, }) // scanTrie leaves the iterator positioned past this owner's range. } return descs, nil } -func scanTrie(it db.Iterator, prefix []byte) (*trie.BitArray, int, error) { - var rootPath *trie.BitArray +func scanTrie(it db.Iterator, prefix []byte) (trie.BitArray, int, error) { + var rootPath trie.BitArray count := 0 for it.Valid() { key := it.Key() @@ -256,11 +255,11 @@ func scanTrie(it db.Iterator, prefix []byte) (*trie.BitArray, int, error) { if len(key) == len(prefix) { val, err := it.Value() if err != nil { - return nil, 0, err + return trie.BitArray{}, 0, err } parsedRootPath, err := parseRootPath(val) if err != nil { - return nil, 0, err + return trie.BitArray{}, 0, err } rootPath = parsedRootPath } else { @@ -271,25 +270,13 @@ func scanTrie(it db.Iterator, prefix []byte) (*trie.BitArray, int, error) { return rootPath, count, nil } -func parseRootPath(val []byte) (*trie.BitArray, error) { +func parseRootPath(val []byte) (trie.BitArray, error) { if len(val) == 0 { - return nil, nil + return trie.BitArray{}, nil } var ba trie.BitArray if err := ba.UnmarshalBinary(val); err != nil { - return nil, err - } - return &ba, nil -} - -func hasDestRoot(r db.KeyValueReader, newBucket db.Bucket, owner *felt.Address) (bool, error) { - var emptyPath trieutils.Path - var buf [trieutils.MaxNodeKeySize]byte - - n := trieutils.EncodeNodeKey(buf[:], newBucket, owner, &emptyPath, false) - if exists, err := r.Has(buf[:n]); err != nil || exists { - return exists, err + return trie.BitArray{}, err } - n = trieutils.EncodeNodeKey(buf[:], newBucket, owner, &emptyPath, true) - return r.Has(buf[:n]) + return ba, nil } diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index 00758af76f..05b125aacc 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -128,7 +128,12 @@ var transcoderCases = []struct { func TestMigrate_FreshDBIsNoOp(t *testing.T) { memDB := memory.New() - state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + state, err := (&trielib.Migrator{}).Migrate( + context.Background(), + memDB, + nil, + log.NewNopZapLogger(), + ) require.NoError(t, err) assert.Nil( t, @@ -143,7 +148,12 @@ func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { require.True(t, bucketHasKeys(t, memDB, db.ClassesTrie), "precondition: DB has old-format data") - state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + state, err := (&trielib.Migrator{}).Migrate( + context.Background(), + memDB, + nil, + log.NewNopZapLogger(), + ) require.NoError(t, err) assert.Nil(t, state, "completed migration must return nil intermediate state") @@ -335,7 +345,12 @@ func TestMigrationMultiStorageOwners(t *testing.T) { db.ContractStorage.Key(ownerBytes[:])) } - _, err := (&trielib.Migrator{}).Migrate(context.Background(), migratedDB, nil, log.NewNopZapLogger()) + _, err := (&trielib.Migrator{}).Migrate( + context.Background(), + migratedDB, + nil, + log.NewNopZapLogger(), + ) require.NoError(t, err) // Per-owner native build → assert every native key is present (with the @@ -366,7 +381,12 @@ func TestMigrationIsNoopOnSecondRun(t *testing.T) { leaves := randomLeaves(100, 11) memDB := buildFullDB(t, leaves) - state, err := (&trielib.Migrator{}).Migrate(context.Background(), memDB, nil, log.NewNopZapLogger()) + state, err := (&trielib.Migrator{}).Migrate( + context.Background(), + memDB, + nil, + log.NewNopZapLogger(), + ) require.NoError(t, err) require.Nil(t, state) snapshot := snapshotAllBuckets(t, memDB, From 2a2fffb6f3fe61026323d42425082eef6b0b13e9 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 23:30:12 +0200 Subject: [PATCH 13/14] chore: add comments --- migration/trie/codec.go | 99 +++++++++------- migration/trie/hashworker.go | 13 +-- migration/trie/ingestor.go | 133 +++++++++++++++------ migration/trie/trie.go | 41 +++++-- migration/trie/trie_test.go | 219 ++++++++++++++++------------------- 5 files changed, 291 insertions(+), 214 deletions(-) diff --git a/migration/trie/codec.go b/migration/trie/codec.go index d9534d9fb4..33b12253c3 100644 --- a/migration/trie/codec.go +++ b/migration/trie/codec.go @@ -16,38 +16,30 @@ const ( binaryNodeBlobSize = 1 + 2*felt.Bytes edgeNodeMinSize = 1 + felt.Bytes + 1 edgeNodeMaxSize = 1 + felt.Bytes + trieutils.MaxBitArraySize + maxNodeKeySize = 1 + felt.Bytes + 1 + trieutils.MaxBitArraySize - // nonLeafByte and leafByte mirror the on-disk byte values of trieutils' - // unexported leafType enum (core/trie2/trieutils/types.go). They're part - // of the persisted trie format so they cannot change; replicated locally - // rather than exposed publicly from trieutils. TestMigrationEndToEnd's - // byte-for-byte equality with a natively-built trie2 guards against drift. nonLeafByte byte = 1 leafByte byte = 2 - - // maxNodeKeySize is the maximum byte length of a new-format trie node key: - // 1 (prefix) + 32 (owner, optional) + 1 (nodeType) + MaxBitArraySize (path). - maxNodeKeySize = 1 + 32 + 1 + trieutils.MaxBitArraySize ) -func toNewPath(old *trie.BitArray) trieutils.Path { - b := old.Bytes() - var p trieutils.Path - p.SetBytes(old.Len(), b[:]) - return p -} +// +// --- Encoding related helpers --- +// -func encodeBinaryNode(leftEdgeHash, rightEdgeHash *felt.Felt) [binaryNodeBlobSize]byte { - var blob [binaryNodeBlobSize]byte - blob[0] = binaryNodeTag +// encodeBinaryNode writes a binary-node blob into dst. +// dst must have at least binaryNodeBlobSize bytes of capacity. +func encodeBinaryNode(dst []byte, leftEdgeHash, rightEdgeHash *felt.Felt) int { + dst[0] = binaryNodeTag lb := leftEdgeHash.Bytes() rb := rightEdgeHash.Bytes() - copy(blob[1:], lb[:]) - copy(blob[1+felt.Bytes:], rb[:]) - return blob + copy(dst[1:], lb[:]) + copy(dst[1+felt.Bytes:], rb[:]) + return binaryNodeBlobSize } -func encodeEdgeNodeInto(dst []byte, childHash *felt.Felt, pathSeg *trieutils.Path) int { +// encodeEdgeNode writes an edge-node blob into dst. +// dst must have at least edgeNodeMaxSize bytes of capacity. +func encodeEdgeNode(dst []byte, childHash *felt.Felt, pathSeg *trieutils.Path) int { encoded := pathSeg.EncodedBytes() dst[0] = edgeNodeTag h := childHash.Bytes() @@ -56,26 +48,6 @@ func encodeEdgeNodeInto(dst []byte, childHash *felt.Felt, pathSeg *trieutils.Pat return 1 + felt.Bytes + len(encoded) } -func computeEdgeHash(childHash *felt.Felt, path *trieutils.Path, hashFn crypto.HashFn) felt.Felt { - if path.Len() == 0 { - return *childHash - } - pathFelt := path.Felt() - h := hashFn(childHash, &pathFelt) - lenFelt := felt.FromUint64[felt.Felt](uint64(path.Len())) - h.Add(&h, &lenFelt) - return h -} - -func compressedSegment(childFullPath *trie.BitArray, parentLen uint8) trieutils.Path { - var seg trie.BitArray - seg.LSBs(childFullPath, parentLen+1) - return toNewPath(&seg) -} - -// encodeNodeKey writes the node key into dst and returns the number of bytes -// written. dst must have at least maxNodeKeySize bytes of capacity. Mirrors -// trieutils.nodeKeyByPath without the per-call allocation. func encodeNodeKey( dst []byte, bucket db.Bucket, @@ -106,3 +78,46 @@ func encodeNodeKey( return n } + +// +// --- Path related helpers --- +// + +func parseDeprecatedPath(val []byte) (trie.BitArray, error) { + if len(val) == 0 { + return trie.BitArray{}, nil + } + var ba trie.BitArray + if err := ba.UnmarshalBinary(val); err != nil { + return trie.BitArray{}, err + } + return ba, nil +} + +func toNewPath(old *trie.BitArray) trieutils.Path { + b := old.Bytes() + var p trieutils.Path + p.SetBytes(old.Len(), b[:]) + return p +} + +func compressedSegment(childFullPath *trie.BitArray, parentLen uint8) trieutils.Path { + var seg trie.BitArray + seg.LSBs(childFullPath, parentLen+1) + return toNewPath(&seg) +} + +// +// --- Hash related helpers --- +// + +func computeEdgeHash(childHash *felt.Felt, path *trieutils.Path, hashFn crypto.HashFn) felt.Felt { + if path.Len() == 0 { + return *childHash + } + pathFelt := path.Felt() + h := hashFn(childHash, &pathFelt) + lenFelt := felt.FromUint64[felt.Felt](uint64(path.Len())) + h.Add(&h, &lenFelt) + return h +} diff --git a/migration/trie/hashworker.go b/migration/trie/hashworker.go index 9c385e222e..0fafa08920 100644 --- a/migration/trie/hashworker.go +++ b/migration/trie/hashworker.go @@ -136,9 +136,8 @@ func (s *hashScheduler) writeBinaryAndEdges( ) error { var buf [maxNodeKeySize + binaryNodeBlobSize]byte keyLen := encodeNodeKey(buf[:], s.bucket, &s.owner, &job.parentPath, false) - blob := encodeBinaryNode(leftEdge, rightEdge) - copy(buf[keyLen:], blob[:]) - if err := batch.Put(buf[:keyLen], buf[keyLen:keyLen+binaryNodeBlobSize]); err != nil { + blobLen := encodeBinaryNode(buf[keyLen:], leftEdge, rightEdge) + if err := batch.Put(buf[:keyLen], buf[keyLen:keyLen+blobLen]); err != nil { return err } @@ -160,8 +159,8 @@ func (s *hashScheduler) writeEdge( } var edgePath trieutils.Path edgePath.AppendBit(parentPath, bit) - var ebuf [maxNodeKeySize + edgeNodeMaxSize]byte - kl := encodeNodeKey(ebuf[:], s.bucket, &s.owner, &edgePath, false) - blob := encodeEdgeNodeInto(ebuf[kl:], childHash, seg) - return batch.Put(ebuf[:kl], ebuf[kl:kl+blob]) + var buf [maxNodeKeySize + edgeNodeMaxSize]byte + keyLen := encodeNodeKey(buf[:], s.bucket, &s.owner, &edgePath, false) + blobLen := encodeEdgeNode(buf[keyLen:], childHash, seg) + return batch.Put(buf[:keyLen], buf[keyLen:keyLen+blobLen]) } diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index 83fdf70a42..ad26c39b60 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -11,8 +11,6 @@ import ( "github.com/NethermindEth/juno/migration/semaphore" ) -const maxOldKeySize = 1 + 32 + 1 + 32 - type task struct { batch db.Batch tries int @@ -76,8 +74,6 @@ func (i *ingestor) Done(index int, outputs chan<- task) error { return nil } -// flush rotates t.batch if it's hit target size: sends the current batch -// downstream and acquires a fresh one. Mutates t in place. func (i *ingestor) flush(t *task, outputs chan<- task) error { if t.batch.Size() < targetBatchByteSize { return nil @@ -93,10 +89,90 @@ func (i *ingestor) flush(t *task, outputs chan<- task) error { return nil } -// migrateTrie writes the new-format representation of a single deprecated -// trie into t.batch. It walks the deprecated trie via DFS, emits value / -// binary / edge nodes through the hashScheduler, and calls flush at each -// step to rotate the batch when it hits target size. +// migrateTrie reads one deprecated trie and writes its equivalent into the +// new layout. Three things differ between the formats: how nodes are keyed +// on disk, how nodes are encoded, and how path compression is expressed. +// +// On-disk keying +// -------------- +// Both layouts share a common prefix; only the suffix differs: +// +// common (both) suffix +// ───────────── ───────────────────────────────────────── +// bucket [|| owner] → path-length-byte || path-bytes (deprecated) +// → nodeType-byte || path-length-byte || path-bytes +// (new) +// +// The owner is present only for storage tries. The new layout's extra +// nodeType byte splits leaves from internal nodes into two index slices +// within the same bucket — the new-state lookups use this to short-circuit +// between leaf reads and internal-node traversals. +// +// Node encoding +// ------------- +// Both layouts are byte streams. The deprecated format keeps each node +// self-contained — internal binary nodes embed the compressed paths to +// their children inline: +// +// leaf value +// binary value || left-child-path || right-child-path +// [|| left-hash || right-hash, optional cache, ignored here] +// +// "value" is the node's own Starknet trie hash, or the stored value when +// the node is a leaf. +// +// The new format gives each node an explicit type tag and moves path +// compression into separate edge nodes: +// +// value value +// binary 0x01 || left-edge-hash || right-edge-hash +// edge 0x02 || child-hash || encoded-path-segment +// +// Path compression +// ---------------- +// This is the key structural change. The deprecated format compresses +// paths inside the parent binary node (via its embedded child-path +// fields). The new format moves compression into dedicated edge nodes +// sitting between binary nodes and their children: +// +// deprecated: binary ──────── child-path ────────► child +// new: binary ──► edge ──► child +// +// The deprecated root marker — a single entry at the bare bucket prefix +// recording the root's path — disappears in the new layout. Whatever the +// deprecated root embedded becomes either a direct binary/leaf at the +// empty path or, when the deprecated root path was itself non-empty, an +// edge node at the empty path that points "down" to the real root. +// +// Traversal +// --------- +// The migrator walks the deprecated trie depth-first, decoding one node at +// a time. A leaf becomes a value node at the same path. An internal binary +// node, after both subtrees have been visited, becomes a binary node plus +// up to two edge nodes (one per non-empty child segment). If the trie's +// stored root path is itself non-empty — meaning the deprecated root +// embeds a compression — a single edge node at the empty path is written +// after the traversal completes, replacing the root marker. +// +// Hashes +// ------ +// Starknet trie hashes: +// +// leaf value +// binary hashFn(left-edge-hash, right-edge-hash) +// edge hashFn(child-hash, path-segment-as-felt) + segment-length +// +// Zero-length edges short-circuit to the bare child-hash — the convention +// for absent edges. Class tries hash with Poseidon; contract and storage +// tries with Pedersen. +// +// For small tries every edge hash is computed inline. Above the threshold, +// edge-hash jobs are batched and dispatched to a worker pool for parallel +// computation; the scheduler preserves the original job order so the +// persisted bytes are byte-identical to a natively-built trie2. +// +// In-flight batches flush at target size; cancellation is observed at +// every flush and every channel send. func (i *ingestor) migrateTrie(t *task, desc TrieDesc, outputs chan<- task) error { if desc.NodeCount == 0 { return nil @@ -113,16 +189,13 @@ func (i *ingestor) migrateTrie(t *task, desc TrieDesc, outputs chan<- task) erro return err } if desc.RootPath.Len() > 0 { - if err := writeRootEdge(desc.RootPath, rootHash, sched, t.batch); err != nil { + if err := writeRootEdgeNode(desc.RootPath, rootHash, sched, t.batch); err != nil { return err } } return nil } -// traverse walks the deprecated trie rooted at oldPath in DFS order, writing -// the new-format equivalents into t.batch via sched. Returns the hash of -// the visited subtree — used by the caller to wire up parent binary nodes. func (i *ingestor) traverse( t *task, outputs chan<- task, @@ -138,7 +211,7 @@ func (i *ingestor) traverse( if parsed.isLeaf { newPath := toNewPath(&oldPath) - if err := processLeaf(newPath, &parsed.value, sched, t.batch); err != nil { + if err := writeLeafNode(newPath, &parsed.value, sched, t.batch); err != nil { return felt.Felt{}, err } if err := i.flush(t, outputs); err != nil { @@ -196,48 +269,36 @@ type parsedNode struct { isLeaf bool } -// readNode loads the deprecated-format node at (prefix, oldPath) and returns -// its parsed fields. The caller owns the result; this function does not -// mutate any input. func readNode(r db.KeyValueReader, prefix []byte, oldPath *trie.BitArray) (parsedNode, error) { - var arr [maxOldKeySize]byte + var arr [maxNodeKeySize]byte n := copy(arr[:], prefix) n += encodeOldPath(oldPath, arr[n:]) var node parsedNode - err := r.Get(arr[:n], func(val []byte) error { - var perr error - node, perr = parseNodeData(val) - return perr - }) + err := r.Get(arr[:n], node.UnmarshalBinary) return node, err } -// parseNodeData decodes a deprecated-format node's raw bytes: -// felt(value) [ BitArray(left) BitArray(right) [ felt felt ] ] -// The trailing left/right hashes are ignored — the migrator re-derives hashes -// itself — so only the fields it actually needs are returned. -func parseNodeData(data []byte) (parsedNode, error) { - var n parsedNode +func (n *parsedNode) UnmarshalBinary(data []byte) error { if len(data) < felt.Bytes { - return n, fmt.Errorf("trie: node data too short (%d bytes)", len(data)) + return fmt.Errorf("trie: node data too short (%d bytes)", len(data)) } n.value = felt.FromBytes[felt.Felt](data[:felt.Bytes]) data = data[felt.Bytes:] if len(data) == 0 { n.isLeaf = true - return n, nil + return nil } if err := n.left.UnmarshalBinary(data); err != nil { - return n, fmt.Errorf("trie: unmarshalling left path: %w", err) + return fmt.Errorf("trie: unmarshalling left path: %w", err) } data = data[n.left.EncodedLen():] if err := n.right.UnmarshalBinary(data); err != nil { - return n, fmt.Errorf("trie: unmarshalling right path: %w", err) + return fmt.Errorf("trie: unmarshalling right path: %w", err) } - return n, nil + return nil } -func processLeaf( +func writeLeafNode( path trieutils.Path, value *felt.Felt, sched *hashScheduler, @@ -268,7 +329,7 @@ func processBinary( }, batch) } -func writeRootEdge( +func writeRootEdgeNode( rootPath *trie.BitArray, childHash felt.Felt, sched *hashScheduler, @@ -276,7 +337,7 @@ func writeRootEdge( ) error { seg := toNewPath(rootPath) var buf [edgeNodeMaxSize]byte - n := encodeEdgeNodeInto(buf[:], &childHash, &seg) + n := encodeEdgeNode(buf[:], &childHash, &seg) return trieutils.WriteNodeByPath( batch, sched.bucket, diff --git a/migration/trie/trie.go b/migration/trie/trie.go index 9356b10782..415ca2a5d6 100644 --- a/migration/trie/trie.go +++ b/migration/trie/trie.go @@ -35,6 +35,34 @@ var ( var deprecatedTrieBuckets = []db.Bucket{db.ClassesTrie, db.StateTrie, db.ContractStorage} +// Migrator converts every deprecated Starknet trie on disk into the +// equivalent trie2 layout used by the new state: +// +// ClassesTrie ─→ ClassTrie (Poseidon) +// StateTrie ─→ ContractTrieContract (Pedersen) +// ContractStorage ─→ ContractTrieStorage (Pedersen, per contract owner) +// +// Each deprecated trie is enumerated (its bucket holds one root-path entry +// plus N node entries), then walked in DFS order; every visited node is +// re-encoded into the new format and written to its destination bucket in +// the same batch. After every trie completes successfully, the three +// deprecated buckets are wiped via DeleteRange. +// +// The pipeline runs IngestorCount worker goroutines, each pulling one trie +// at a time from the enumeration source, plus a single committer that +// flushes filled batches to disk. A semaphore caps in-flight batches at +// IngestorCount * 2. See migrateTrie for the per-trie traversal. +// +// Re-run safe: every trie's first action checks for its new-format root +// key; if present, the trie is treated as already migrated and skipped. +// A subsequent run after a crash picks up where the previous one stopped — +// partially migrated tries either have a root key (skipped on the next +// pass) or don't (re-migrated from scratch; the deprecated source data is +// still present because the trailing wipe runs only on full success). +// +// Cancellation: every flush and every channel send checks ctx.Done. On +// cancel, Migrate returns the shouldRerun sentinel with ctx.Err(); the +// migration runner re-invokes on the next process start. type Migrator struct{} var _ migration.Migration = (*Migrator)(nil) @@ -257,7 +285,7 @@ func scanTrie(it db.Iterator, prefix []byte) (trie.BitArray, int, error) { if err != nil { return trie.BitArray{}, 0, err } - parsedRootPath, err := parseRootPath(val) + parsedRootPath, err := parseDeprecatedPath(val) if err != nil { return trie.BitArray{}, 0, err } @@ -269,14 +297,3 @@ func scanTrie(it db.Iterator, prefix []byte) (trie.BitArray, int, error) { } return rootPath, count, nil } - -func parseRootPath(val []byte) (trie.BitArray, error) { - if len(val) == 0 { - return trie.BitArray{}, nil - } - var ba trie.BitArray - if err := ba.UnmarshalBinary(val); err != nil { - return trie.BitArray{}, err - } - return ba, nil -} diff --git a/migration/trie/trie_test.go b/migration/trie/trie_test.go index 05b125aacc..c2d08b9df0 100644 --- a/migration/trie/trie_test.go +++ b/migration/trie/trie_test.go @@ -2,7 +2,6 @@ package trie_test import ( "context" - "math/rand" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -27,8 +26,8 @@ type trieCase struct { oldBucket db.Bucket newBucket db.Bucket owner felt.Address - oldBuildPrefix func(owner felt.Address) []byte - newTrieID func(owner felt.Address) trieutils.TrieID + oldBuildPrefix func(owner *felt.Address) []byte + newTrieID func(owner *felt.Address) trieutils.TrieID hashFn crypto.HashFn //nolint:staticcheck // Necessary for old state buildOldFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error) @@ -39,8 +38,8 @@ var trieCases = []trieCase{ name: "ClassTrie", oldBucket: db.ClassesTrie, newBucket: db.ClassTrie, - oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.ClassesTrie)} }, - newTrieID: func(_ felt.Address) trieutils.TrieID { + oldBuildPrefix: func(_ *felt.Address) []byte { return []byte{byte(db.ClassesTrie)} }, + newTrieID: func(_ *felt.Address) trieutils.TrieID { return trieutils.NewClassTrieID(felt.StateRootHash(felt.One)) }, hashFn: crypto.Poseidon, @@ -50,8 +49,8 @@ var trieCases = []trieCase{ name: "ContractTrie", oldBucket: db.StateTrie, newBucket: db.ContractTrieContract, - oldBuildPrefix: func(_ felt.Address) []byte { return []byte{byte(db.StateTrie)} }, - newTrieID: func(_ felt.Address) trieutils.TrieID { + oldBuildPrefix: func(_ *felt.Address) []byte { return []byte{byte(db.StateTrie)} }, + newTrieID: func(_ *felt.Address) trieutils.TrieID { return trieutils.NewContractTrieID(felt.StateRootHash(felt.One)) }, hashFn: crypto.Pedersen, @@ -62,37 +61,18 @@ var trieCases = []trieCase{ oldBucket: db.ContractStorage, newBucket: db.ContractTrieStorage, owner: felt.FromUint64[felt.Address](42), - oldBuildPrefix: func(owner felt.Address) []byte { - ownerFelt := felt.Felt(owner) - ownerBytes := ownerFelt.Bytes() + oldBuildPrefix: func(owner *felt.Address) []byte { + ownerBytes := owner.Bytes() return db.ContractStorage.Key(ownerBytes[:]) }, - newTrieID: func(owner felt.Address) trieutils.TrieID { - return trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), owner) + newTrieID: func(owner *felt.Address) trieutils.TrieID { + return trieutils.NewContractStorageTrieID(felt.StateRootHash(felt.One), *owner) }, hashFn: crypto.Pedersen, buildOldFn: trie.NewTriePedersen, }, } -// randomLeaves generates n distinct leaf key-value pairs using a fixed seed, -// with keys spread across the full 251-bit felt range for structural variety. -func randomLeaves(n int, seed int64) leafMap { - rng := rand.New(rand.NewSource(seed)) - leaves := make(leafMap, n) - var kb, vb [32]byte - for len(leaves) < n { - rng.Read(kb[:]) - rng.Read(vb[:]) - // Clear the top 5 bits so all keys are safely below the StarkNet prime (~2^251+δ). - kb[0] &= 0x07 - k := felt.FromBytes[felt.Felt](kb[:]) - v := felt.FromBytes[felt.Felt](vb[:]) - leaves[k] = v - } - return leaves -} - var transcoderCases = []struct { name string leaves leafMap @@ -122,7 +102,7 @@ var transcoderCases = []struct { } return leaves }()}, - {"random 1000 leaves", randomLeaves(1000, 42)}, + {"random 1000 leaves", randomLeaves(1000)}, } func TestMigrate_FreshDBIsNoOp(t *testing.T) { @@ -143,7 +123,7 @@ func TestMigrate_FreshDBIsNoOp(t *testing.T) { } func TestMigrate_RunsWhenOldDataPresent(t *testing.T) { - leaves := randomLeaves(100, 7) + leaves := randomLeaves(100) memDB := buildFullDB(t, leaves) require.True(t, bucketHasKeys(t, memDB, db.ClassesTrie), "precondition: DB has old-format data") @@ -187,7 +167,7 @@ func TestMigrationEndToEnd(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - prefix := c.tc.oldBuildPrefix(c.tc.owner) + prefix := c.tc.oldBuildPrefix(&c.tc.owner) migratedDB := memory.New() buildDeprecatedTrie(t, migratedDB, c.leaves, c.tc.buildOldFn, prefix) @@ -198,7 +178,7 @@ func TestMigrationEndToEnd(t *testing.T) { nativeDB := memory.New() buildTrie(t, nativeDB, c.leaves, - c.tc.newTrieID(c.tc.owner), c.tc.hashFn, c.tc.newBucket) + c.tc.newTrieID(&c.tc.owner), c.tc.hashFn, c.tc.newBucket) assert.Equal(t, allKeysUnder(t, nativeDB, c.tc.newBucket), @@ -213,7 +193,7 @@ func TestMigrationEndToEnd(t *testing.T) { // faked by copying the reference DB's new-format class-trie keys into the // partial DB before running migration. func TestMigrationIsResumable(t *testing.T) { - leaves := randomLeaves(1000, 42) + leaves := randomLeaves(1000) // Reference: full migration from scratch. refDB := buildFullDB(t, leaves) @@ -245,87 +225,11 @@ func TestMigrationIsResumable(t *testing.T) { } } -// buildFullDB creates an old-format DB populated with a class, a contract, and -// one storage trie, all built from the same leaf set. -func buildFullDB(t *testing.T, leaves leafMap) db.KeyValueStore { - t.Helper() - database := memory.New() - - buildDeprecatedTrie(t, database, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) - buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, db.StateTrie.Key()) - - var ownerFelt felt.Felt - ownerFelt.SetUint64(42) - ownerBytes := ownerFelt.Bytes() - storagePrefix := db.ContractStorage.Key(ownerBytes[:]) - buildDeprecatedTrie(t, database, leaves, trie.NewTriePedersen, storagePrefix) - - return database -} - -func buildDeprecatedTrie( - t *testing.T, - database db.KeyValueStore, - leaves leafMap, - //nolint:staticcheck // Necessary for old state - trieFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error), - prefix []byte, -) felt.Felt { - t.Helper() - //nolint:staticcheck // Necessary for old state - txn := database.NewIndexedBatch() - tr, err := trieFn(txn, prefix, 251) - require.NoError(t, err) - for key, value := range leaves { - _, err := tr.Put(&key, &value) - require.NoError(t, err) - } - root, err := tr.Root() - require.NoError(t, err) - require.NoError(t, tr.Commit()) - require.NoError(t, txn.Write()) - return root -} - -// buildTrie builds a trie2 natively from leaves and persists it to kvStore. -// newBucket distinguishes class trie (db.ClassTrie) from contract/storage tries — -// it controls which Update argument the NodeSet is passed as. -func buildTrie( - t *testing.T, - kvStore db.KeyValueStore, - leaves leafMap, - id trieutils.TrieID, - hashFn crypto.HashFn, - newBucket db.Bucket, -) { - t.Helper() - rawDB := rawdb.New(kvStore) - tr, err := trie2.New(id, 251, hashFn, rawDB) - require.NoError(t, err) - for key, value := range leaves { - require.NoError(t, tr.Update(&key, &value)) - } - root, nodes := tr.Commit() - if nodes == nil { - return // empty trie — nothing to persist - } - mergeSet := trienode.NewMergeNodeSet(nodes) - var zero felt.StateRootHash - stateRoot := felt.StateRootHash(root) - batch := kvStore.NewBatch() - if newBucket == db.ClassTrie { - require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, mergeSet, nil, batch)) - } else { - require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, nil, mergeSet, batch)) - } - require.NoError(t, batch.Write()) -} - // TestMigrationMultiStorageOwners exercises enumerateStorageTries across // multiple owners (scanTrie's prefix-leave path) and keeps all 4 ingestor // workers busy by giving them 7 tries to chew through (2 global + 5 storage). func TestMigrationMultiStorageOwners(t *testing.T) { - leaves := randomLeaves(50, 7) + leaves := randomLeaves(50) migratedDB := memory.New() buildDeprecatedTrie(t, migratedDB, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) @@ -339,8 +243,7 @@ func TestMigrationMultiStorageOwners(t *testing.T) { felt.FromUint64[felt.Address](999), } for _, owner := range owners { - ownerFelt := felt.Felt(owner) - ownerBytes := ownerFelt.Bytes() + ownerBytes := owner.Bytes() buildDeprecatedTrie(t, migratedDB, leaves, trie.NewTriePedersen, db.ContractStorage.Key(ownerBytes[:])) } @@ -354,7 +257,6 @@ func TestMigrationMultiStorageOwners(t *testing.T) { require.NoError(t, err) // Per-owner native build → assert every native key is present (with the - // same value) under the merged migrated view. migratedAll := allKeysUnder(t, migratedDB, db.ContractTrieStorage) for _, owner := range owners { nativeDB := memory.New() @@ -378,7 +280,7 @@ func TestMigrationMultiStorageOwners(t *testing.T) { // a second run: needsMigration sees the wiped deprecated buckets and returns // early without touching the migrated state. func TestMigrationIsNoopOnSecondRun(t *testing.T) { - leaves := randomLeaves(100, 11) + leaves := randomLeaves(100) memDB := buildFullDB(t, leaves) state, err := (&trielib.Migrator{}).Migrate( @@ -405,7 +307,7 @@ func TestMigrationIsNoopOnSecondRun(t *testing.T) { // context.Canceled with the shouldRerun sentinel, and that a fresh ctx // completes the migration normally afterwards. func TestMigrationCancelledContext(t *testing.T) { - leaves := randomLeaves(100, 13) + leaves := randomLeaves(100) memDB := buildFullDB(t, leaves) ctx, cancel := context.WithCancel(context.Background()) @@ -421,6 +323,89 @@ func TestMigrationCancelledContext(t *testing.T) { require.Nil(t, state) } +// buildFullDB creates an old-format DB populated with a class, a contract, and +// one storage trie, all built from the same leaf set. +func buildFullDB(t *testing.T, leaves leafMap) db.KeyValueStore { + t.Helper() + memDB := memory.New() + + owner := felt.FromUint64[felt.Address](42) + ownerBytes := owner.Bytes() + storagePrefix := db.ContractStorage.Key(ownerBytes[:]) + + buildDeprecatedTrie(t, memDB, leaves, trie.NewTriePoseidon, db.ClassesTrie.Key()) + buildDeprecatedTrie(t, memDB, leaves, trie.NewTriePedersen, db.StateTrie.Key()) + buildDeprecatedTrie(t, memDB, leaves, trie.NewTriePedersen, storagePrefix) + + return memDB +} + +func buildDeprecatedTrie( + t *testing.T, + database db.KeyValueStore, + leaves leafMap, + //nolint:staticcheck // Necessary for old state + trieFn func(db.IndexedBatch, []byte, uint8) (*trie.Trie, error), + prefix []byte, +) felt.Felt { + t.Helper() + //nolint:staticcheck // Necessary for old state + txn := database.NewIndexedBatch() + tr, err := trieFn(txn, prefix, 251) + require.NoError(t, err) + for key, value := range leaves { + _, err := tr.Put(&key, &value) + require.NoError(t, err) + } + root, err := tr.Root() + require.NoError(t, err) + require.NoError(t, tr.Commit()) + require.NoError(t, txn.Write()) + return root +} + +func buildTrie( + t *testing.T, + kvStore db.KeyValueStore, + leaves leafMap, + id trieutils.TrieID, + hashFn crypto.HashFn, + newBucket db.Bucket, +) { + t.Helper() + rawDB := rawdb.New(kvStore) + tr, err := trie2.New(id, 251, hashFn, rawDB) + require.NoError(t, err) + for key, value := range leaves { + require.NoError(t, tr.Update(&key, &value)) + } + root, nodes := tr.Commit() + if nodes == nil { + return // empty trie — nothing to persist + } + mergeSet := trienode.NewMergeNodeSet(nodes) + var zero felt.StateRootHash + stateRoot := felt.StateRootHash(root) + batch := kvStore.NewBatch() + if newBucket == db.ClassTrie { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, mergeSet, nil, batch)) + } else { + require.NoError(t, rawDB.Update(&stateRoot, &zero, 0, nil, mergeSet, batch)) + } + require.NoError(t, batch.Write()) +} + +func randomLeaves(n int) leafMap { + leaves := make(leafMap, n) + for len(leaves) < n { + var k, v felt.Felt + k.SetRandom() + v.SetRandom() + leaves[k] = v + } + return leaves +} + func snapshotAllBuckets(t *testing.T, r db.KeyValueReader, buckets ...db.Bucket) map[string][]byte { t.Helper() out := make(map[string][]byte) From df52a016e57eb882205e7ba0947a1f2d4ac7e496 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 20 May 2026 23:43:50 +0200 Subject: [PATCH 14/14] chore: linter --- migration/trie/ingestor.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migration/trie/ingestor.go b/migration/trie/ingestor.go index ad26c39b60..76d85f56e0 100644 --- a/migration/trie/ingestor.go +++ b/migration/trie/ingestor.go @@ -135,8 +135,8 @@ func (i *ingestor) flush(t *task, outputs chan<- task) error { // fields). The new format moves compression into dedicated edge nodes // sitting between binary nodes and their children: // -// deprecated: binary ──────── child-path ────────► child -// new: binary ──► edge ──► child +// old: binary ──────── child-path ────────► child +// new: binary ──► edge ──► child // // The deprecated root marker — a single entry at the bare bucket prefix // recording the root's path — disappears in the new layout. Whatever the